Unverified Commit 599258f9 authored by Samyam Rajbhandari, committed by GitHub

ZeRO 3 Offload (#834)



* Squash stage3 v1 (#146)
Co-authored-by: Samyam <samyamr@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>

* Fix correctness bug (#147)

* formatting fix (#150)

* stage3 bugfix (API) update and simplified FP16 Z3 tests (#151)

* fp16 Z3 API update and bugfix

* revert debug change

* ZeRO-3 detach and race condition bugfixes (#149)

* trying out ZeRO-3 race condition fix

* CUDA sync instead of stream

* reduction stream sync

* remove commented code

* Fix optimizer state_dict KeyError (#148)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* fix for smaller SGS sizes, ensures each grad is backed by unique tensors (#152)

* Simplifying the logic for getting averaged gradients (#153)

* skip for now

* Z3 Docs redux (#154)

* removing some TODOs and commented code (#155)

* New Z3 defaults (#156)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* formatting

* megatron external params
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>
parent ba33e86e
......@@ -48,4 +48,4 @@ jobs:
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose tests/unit/
......@@ -16,6 +16,8 @@ from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConf
from .utils import log_dist
from .utils.distributed import init_distributed
from .runtime import zero
from .pipe import PipelineModule
from .git_version_info import version, git_hash, git_branch
......
......@@ -304,7 +304,7 @@ def main(args=None):
# encode world info as base64 to make it easier to pass via command line
world_info_base64 = encode_world_info(active_resources)
multi_node_exec = len(active_resources) > 1
multi_node_exec = True # len(active_resources) > 1
if multi_node_exec and not shutil.which('pdsh'):
raise RuntimeError("pdsh is not installed, unable to proceed")
......
......@@ -10,41 +10,6 @@ from ..op_builder import CPUAdamBuilder
class DeepSpeedCPUAdam(torch.optim.Optimizer):
"""Fast vectorized implementation of two variations of Adam optimizer on CPU:
- Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
- AdamW: FIXING WEIGHT DECAY REGULARIZATION IN ADAM (https://arxiv.org/abs/1711.05101v1)
DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W).
In order to apply this optimizer, the model requires its master parameters (in FP32)
to reside in CPU memory.
To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role in minimizing
the optimizer's latency overhead on CPU. Please refer to the ZeRO-Offload tutorial
(https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
When calling the step function, two options are available: (1) update the optimizer's states, or (2) update
the optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second
option can bring 30% higher throughput than doing the copy separately using option one.
Arguments:
model_params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
adamw_mode: select between Adam and AdamW implementations (default: AdamW)
"""
optimizer_id = 0
def __init__(self,
......@@ -57,6 +22,47 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
weight_decay=0,
amsgrad=False,
adamw_mode=True):
"""Fast vectorized implementation of two variations of Adam optimizer on CPU:
* Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
* AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101)
DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W).
In order to apply this optimizer, the model requires its master parameters (in FP32)
to reside in CPU memory.
To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role in minimizing
the optimizer's latency overhead on CPU. Please refer to the ZeRO-Offload tutorial
(https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
When calling the step function, two options are available: (1) update the optimizer's states, or (2) update
the optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second
option can bring 30% higher throughput than doing the copy separately using option one.
.. note::
We recommend using our `config
<https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
to allow :meth:`deepspeed.initialize` to build this optimizer
for you.
Arguments:
model_params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
adamw_mode: select between Adam and AdamW implementations (default: AdamW)
"""
default_args = dict(lr=lr,
betas=betas,
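# --- Illustrative usage sketch (editor's addition, not part of the diff) ---
# A minimal, hedged example of constructing DeepSpeedCPUAdam directly. It assumes the
# CPUAdamBuilder op has been compiled and that the import path is deepspeed.ops.adam
# (path assumed); the recommended route remains letting deepspeed.initialize() build the
# optimizer from the "optimizer" section of the JSON config, as the docstring notes.
#
#   import torch
#   from deepspeed.ops.adam import DeepSpeedCPUAdam
#
#   model = torch.nn.Linear(10, 10)          # toy model; FP32 master params stay on CPU
#   optimizer = DeepSpeedCPUAdam(model.parameters(),
#                                lr=1e-3,
#                                betas=(0.9, 0.999),
#                                eps=1e-8,
#                                weight_decay=0,
#                                adamw_mode=True)
#   loss = model(torch.randn(4, 10)).sum()
#   loss.backward()
#   optimizer.step()                         # option (1): update optimizer states only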
......@@ -86,6 +92,24 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
@torch.no_grad()
def step(self, closure=None, fp16_param_groups=None):
"""Update the model parameters.
.. note::
This method will be called internally by ZeRO-Offload. DeepSpeed
users should still use ``engine.step()`` as shown in the
`Getting Started
<https://www.deepspeed.ai/getting-started/#training>`_ guide.
Args:
closure (callable, optional): closure to compute the loss.
Defaults to ``None``.
fp16_param_groups: FP16 GPU parameters to update. Performing the
copy here reduces communication time. Defaults to ``None``.
Returns:
loss: if ``closure`` is provided. Otherwise ``None``.
"""
loss = None
if closure is not None:
with torch.enable_grad():
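# --- Illustrative usage sketch (editor's addition, not part of the diff) ---
# As the note above says, ZeRO-Offload invokes this step() internally; user code should
# keep driving training through the engine. A hedged sketch, assuming `args`, `model`,
# `params`, and `data_loader` are defined elsewhere and the config selects this optimizer:
#
#   import deepspeed
#
#   model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
#                                                        model=model,
#                                                        model_parameters=params)
#   for batch in data_loader:
#       loss = model_engine(batch)
#       model_engine.backward(loss)
#       model_engine.step()    # internally calls DeepSpeedCPUAdam.step() when offloading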
......@@ -100,7 +124,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
state = self.state[p]
# State initialization
if len(state) == 0:
print(f'group {group_id} param {param_id} = {p.numel()}')
#print(f'group {group_id} param {param_id} = {p.numel()}')
state['step'] = 0
# gradient momentums
state['exp_avg'] = torch.zeros_like(p.data,
......
......@@ -18,6 +18,7 @@ import torch
import contextlib
import torch.distributed as dist
import mmap
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
......@@ -26,19 +27,19 @@ from deepspeed.utils import logger
from deepspeed.runtime.utils import move_to_device
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers
#DeepSpeed Checkpointing Enabled or Disabled
# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
#MP parameters
# MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None
#Model Parameters
# Model Parameters
num_layers = None
#Checkpointing buffers
# Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []
......@@ -47,7 +48,7 @@ size_offsets = []
timers = None
#optimization flags
# optimization flags
PARTITION_ACTIVATIONS = False
PA_TO_CPU = False
CONTIGUOUS_CHECKPOINTING = False
......@@ -56,10 +57,10 @@ PROFILE_TIME = False
def see_memory_usage(message, force=False):
#return
# return
if not force:
return
#dist.barrier()
# dist.barrier()
if dist.get_rank() == 0:
logger.info(message)
logger.info(
......@@ -78,6 +79,7 @@ def see_memory_usage(message, force=False):
"Max cache Allocated %s GigaBytes",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
)
logger.info("")
#input("Press Any Key To Continue ..")
......@@ -348,7 +350,22 @@ def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
tensor_idx = 0
non_tensor_idx = 0
for is_tensor in tensor_flags:
real_tensor_flags = None
#remove the flags that are assigned to the size of the flattened tensors
if PARTITION_ACTIVATIONS:
real_tensor_flags = []
previous_flag = False
for flag in tensor_flags:
if previous_flag:
previous_flag = False
continue
previous_flag = flag
real_tensor_flags.append(flag)
else:
real_tensor_flags = tensor_flags
for is_tensor in real_tensor_flags:
if is_tensor:
merged_objects.append(tensor_objects[tensor_idx])
tensor_idx += 1
......@@ -406,7 +423,7 @@ class CheckpointFunction(torch.autograd.Function):
global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
if cuda_device is None:
see_memory_usage("First Forward Begining", force=True)
see_memory_usage("First Forward Begining", force=False)
if dist.get_rank() == 0:
logger.info(f"Activation Checkpointing Information")
logger.info(
......@@ -423,7 +440,7 @@ class CheckpointFunction(torch.autograd.Function):
if PARTITION_ACTIVATIONS:
#inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
#inputs.append(args[-1])
# inputs.append(args[-1])
inputs = []
for i, item in enumerate(args[:-1]):
......@@ -460,6 +477,19 @@ class CheckpointFunction(torch.autograd.Function):
contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0
# Because the 'new_empty' returns uninitialized pages,
# the pages need to be populated during the cudaMemcpy time
# which increases the data copy time. To avoid this, we
# pre-populate these pages by simply writing 0 ahead of
# the actual cudaMemcpy operation time. Due to the
# previously launched GPU kernels, there is a small
# window of time here for CPUs to populate pages asynchronously.
contiguous_data_buffers[i][data_offsets[i]].data[range(
0,
contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
int(mmap.PAGESIZE / contiguous_data_buffers[i][
data_offsets[i]].data.element_size()))] = 0
contiguous_partition = contiguous_data_buffers[i][
data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
......@@ -478,14 +508,16 @@ class CheckpointFunction(torch.autograd.Function):
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
#ctx.save_for_backward(*args)
see_memory_usage("Before running forward on the layer", force=False)
# ctx.save_for_backward(*args)
with torch.no_grad():
outputs = run_function(*inputs_cuda)
see_memory_usage("After running forward on the layer", force=False)
del inputs_cuda
#with torch.cuda.stream(transport_stream):
#if PARTITION_ACTIVATIONS:
# with torch.cuda.stream(transport_stream):
# if PARTITION_ACTIVATIONS:
# new_args = []
# for arg, inp in zip(args,inputs):
# size= torch.tensor(arg.size())
......@@ -531,7 +563,7 @@ class CheckpointFunction(torch.autograd.Function):
new_args.append(contiguous_size)
else:
new_args.append(size)
#if dist.get_rank() == 0:
# if dist.get_rank() == 0:
# logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")
save_args_for_backward(*new_args)
......@@ -564,10 +596,10 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grads):
global timers
#see_memory_usage("In backward", force=True)
#removing pointers to the contiguous buffer memory
#so that they can be garbage collected once the checkpoints
#have been used
see_memory_usage("In backward", force=False)
# removing pointers to the contiguous buffer memory
# so that they can be garbage collected once the checkpoints
# have been used
if SYNCHRONIZE:
torch.cuda.synchronize()
if PROFILE_TIME:
......@@ -580,14 +612,14 @@ class CheckpointFunction(torch.autograd.Function):
for buffers in contiguous_data_buffers:
buffers = []
#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
# frees up all the pointers to the checkpoints except for the ones
# stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
#see_memory_usage("In backward checkpointing code", force=True)
see_memory_usage("In backward checkpointing code", force=False)
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
......@@ -595,7 +627,7 @@ class CheckpointFunction(torch.autograd.Function):
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
if PARTITION_ACTIVATIONS:
#with torch.cuda.stream(transport_stream):
# with torch.cuda.stream(transport_stream):
inputs = get_full_inputs(ctx.saved_tensors,
device=cuda_device if PA_TO_CPU else None)
detached_inputs = detach_variable(inputs)
......@@ -622,9 +654,12 @@ class CheckpointFunction(torch.autograd.Function):
# current_stream=torch.cuda.current_stream()
# current_stream.wait_stream(transport_stream)
see_memory_usage("In backward checkpointing code before forward", force=False)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
see_memory_usage("In backward checkpointing code after forward", force=False)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
......@@ -646,8 +681,13 @@ class CheckpointFunction(torch.autograd.Function):
output_tensors.append(out)
grad_tensors.append(grad)
see_memory_usage("In backward checkpointing code before backward", force=False)
torch.autograd.backward(output_tensors, grad_tensors)
see_memory_usage("After backward checkpointing code before backward",
force=False)
if PROFILE_TIME:
timers('backward').stop()
timers.log(['backward'])
......@@ -706,8 +746,8 @@ def reset():
for buffers in contiguous_data_buffers:
buffers = []
#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
# frees up all the pointers to the checkpoints except for the ones
# stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
......@@ -716,10 +756,11 @@ def reset():
def _configure_using_config_file(deepspeed_config, mpu=None):
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config
logger.info(config.repr())
if dist.get_rank() == 0:
logger.info(config.repr())
PARTITION_ACTIVATIONS = config.partition_activations
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
num_layers = config.number_checkpoints
......@@ -733,7 +774,7 @@ def _configure_defaults():
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
PARTITION_ACTIVATIONS = False
CONTIGUOUS_CHECKPOINTING = False
......@@ -792,7 +833,7 @@ def configure(
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
_configure_defaults()
......
......@@ -752,9 +752,9 @@ class DeepSpeedConfig(object):
if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
if self.zero_config.cpu_offload is True:
assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
#assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
#if self.zero_config.cpu_offload is True:
# assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
#assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
def _do_warning_check(self):
fp16_enabled = self.fp16_enabled or self.zero_enabled
......
......@@ -12,8 +12,10 @@ from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
......@@ -27,7 +29,7 @@ from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA
from deepspeed.runtime.zero.constants import \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger, log_dist, init_distributed
......@@ -105,8 +107,10 @@ class DeepSpeedEngine(Module):
mpu=None,
dist_init_required=None,
collate_fn=None,
config_params=None):
config_params=None,
dont_change_device=False):
super(DeepSpeedEngine, self).__init__()
self.dont_change_device = dont_change_device
self.client_optimizer = optimizer
self.client_model_parameters = model_parameters
self.client_lr_scheduler = lr_scheduler
......@@ -136,6 +140,7 @@ class DeepSpeedEngine(Module):
# Initialize torch distributed if needed
init_distributed(dist_backend=self.dist_backend)
see_memory_usage(f"DeepSpeed Engine: Before args sanity test")
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
......@@ -149,9 +154,13 @@ class DeepSpeedEngine(Module):
if self.tensorboard_enabled() and self.global_rank == 0:
self.summary_writer = self.get_summary_writer()
see_memory_usage(f"DeepSpeed Engine: Before configure distributed model")
# Configure distributed model
self._configure_distributed_model(model)
see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
# Configure wall clock timer
self.timers = SynchronizedWallClockTimer()
......@@ -331,6 +340,15 @@ class DeepSpeedEngine(Module):
def zero_cpu_offload(self):
return self._config.zero_config.cpu_offload
def zero_cpu_offload_params(self):
return self._config.zero_config.cpu_offload_params
def zero_cpu_offload_use_pin_memory(self):
return self._config.zero_config.cpu_offload_use_pin_memory
def zero_sub_group_size(self):
return self._config.zero_config.sub_group_size
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
......@@ -343,6 +361,9 @@ class DeepSpeedEngine(Module):
def zero_optimization_partition_gradients(self):
return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS
def zero_optimization_partition_weights(self):
return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_WEIGHTS
def zero_contiguous_gradients(self):
return self._config.zero_config.contiguous_gradients
......@@ -352,6 +373,18 @@ class DeepSpeedEngine(Module):
def zero_elastic_checkpoint(self):
return self._config.zero_config.elastic_checkpoint
def zero_max_live_parameters(self):
return self._config.zero_config.max_live_parameters
def zero_max_reuse_distance(self):
return self._config.zero_config.max_reuse_distance
def zero_prefetch_bucket_size(self):
return self._config.zero_config.prefetch_bucket_size
def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold
def fp16_enabled(self):
return self._config.fp16_enabled
......@@ -418,7 +451,8 @@ class DeepSpeedEngine(Module):
dp_rank = self.mpu.get_data_parallel_rank()
# only the first data parallel process needs to store the model checkpoint
self.save_non_zero_checkpoint = (dp_rank == 0)
self.save_non_zero_checkpoint = (
dp_rank == 0) or self.zero_optimization_partition_weights()
if self.zero_optimization():
param_rank = torch.distributed.get_rank(
......@@ -512,8 +546,13 @@ class DeepSpeedEngine(Module):
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())
def _broadcast_model(self):
def is_replicated(p):
if hasattr(p, 'ds_status') and p.ds_status is not ZeroParamStatus.AVAILABLE:
return False
return True
for p in self.module.parameters():
if torch.is_tensor(p):
if torch.is_tensor(p) and is_replicated(p):
dist.broadcast(p,
self.broadcast_src_rank,
group=self.data_parallel_group)
......@@ -522,7 +561,9 @@ class DeepSpeedEngine(Module):
self.module = model
if self.fp16_enabled():
self.module.half()
self.module.to(self.device)
if not self.dont_change_device:
self.module.to(self.device)
if self.mpu is None:
self.data_parallel_group = _initialize_parameter_parallel_groups()
......@@ -555,7 +596,8 @@ class DeepSpeedEngine(Module):
self.optimizer_name()))
if self.global_rank == 0:
logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))
logger.info('DeepSpeed Basic Optimizer = {}'.format(
basic_optimizer.__class__.__name__))
if self.zero_optimization():
assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
......@@ -585,7 +627,8 @@ class DeepSpeedEngine(Module):
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
else:
self.optimizer = basic_optimizer
logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
log_dist('DeepSpeed Final Optimizer = {}'.format(self.optimizer_name()),
ranks=[0])
def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
......@@ -636,7 +679,7 @@ class DeepSpeedEngine(Module):
if isinstance(optimizer,
FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale')
log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
timers = self.timers if self.wall_clock_breakdown() else None
optimizer = FP16_Optimizer(
optimizer,
......@@ -648,8 +691,9 @@ class DeepSpeedEngine(Module):
fused_adam_legacy=self.optimizer_legacy_fusion(),
timers=timers)
else:
logger.info('Creating fp16 optimizer with static loss scale: {}'.format(
self.loss_scale()))
log_dist('Creating fp16 optimizer with static loss scale: {}'.format(
self.loss_scale()),
ranks=[0])
optimizer = FP16_Optimizer(
optimizer,
static_loss_scale=self.loss_scale(),
......@@ -657,7 +701,8 @@ class DeepSpeedEngine(Module):
clip_grad=clip_grad,
fused_adam_legacy=self.optimizer_legacy_fusion())
else:
logger.info('Creating fp16 unfused optimizer with dynamic loss scale')
log_dist('Creating fp16 unfused optimizer with dynamic loss scale',
ranks=[0])
optimizer = FP16_UnfusedOptimizer(
optimizer,
static_loss_scale=self.loss_scale(),
......@@ -671,8 +716,9 @@ class DeepSpeedEngine(Module):
def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))
log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
......@@ -706,6 +752,35 @@ class DeepSpeedEngine(Module):
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps())
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=self.timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
cpu_offload_optimizer_state=self.zero_cpu_offload(),
cpu_offload_params=self.zero_cpu_offload_params(),
cpu_offload_use_pin_memory=self.zero_cpu_offload_use_pin_memory(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps())
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
......@@ -817,6 +892,11 @@ class DeepSpeedEngine(Module):
self.tput_timer.start()
loss = self.module(*inputs, **kwargs)
# Reset the ZeRO-3 state if we are only doing forward-passes (ie evaluation).
if self.zero_optimization_partition_weights():
if not torch._C.is_grad_enabled():
self.optimizer.param_coordinator.reset_step()
if self.wall_clock_breakdown():
self.timers('forward').stop()
self.timers('forward_microstep').stop()
......@@ -1267,9 +1347,18 @@ class DeepSpeedEngine(Module):
def _get_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
ckpt_name = os.path.join(checkpoints_path,
str(tag),
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
if self.zero_optimization_partition_weights():
filename = 'zero_pp_rank_{}'.format(
torch.distributed.get_rank(group=self.optimizer.dp_process_group))
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
filename + '_mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
else:
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
return ckpt_name
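# --- Illustrative note (editor's addition, not part of the diff) ---
# With partitioned weights (ZeRO stage 3), every data-parallel rank now writes its own
# model-states file. For an example tag "global_step100", dp rank 0, mp rank 0:
#   <checkpoints_path>/global_step100/zero_pp_rank_0_mp_rank_00_model_states.pt
# Without weight partitioning the name is unchanged:
#   <checkpoints_path>/global_step100/mp_rank_00_model_states.pt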
def load_checkpoint(self,
......@@ -1478,6 +1567,10 @@ class DeepSpeedEngine(Module):
process with rank 0.
"""
if self.zero_optimization_partition_weights():
# Prepare for state_dict() by ensuring all parameters are partitioned
self.optimizer.save_checkpoint_prologue()
# This is to make sure the checkpoint names are created without collision
# There seems to be issue creating them in parallel
......@@ -1506,6 +1599,9 @@ class DeepSpeedEngine(Module):
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)
if self.zero_optimization_partition_weights():
self.optimizer.save_checkpoint_epilogue()
return True
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
......
......@@ -7,6 +7,7 @@ Helper functions and classes from multiple sources.
'''
import os
import psutil
from math import ceil
from math import floor
from bisect import bisect_left, bisect_right
......@@ -72,7 +73,7 @@ class CheckOverflow(object):
self.params.append(param)
def check_using_norm(self, norm_group, reduce_overflow=True):
#TODO: I don't think reduce_overflow is needed if mpu is None
# TODO: I don't think reduce_overflow is needed if mpu is None
overflow = -1 in norm_group
if self.mpu is not None:
......@@ -115,7 +116,7 @@ class CheckOverflow(object):
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
overflow_gpu = torch.cuda.ByteTensor([overflow])
#torch.distributed.all_reduce(overflow_gpu,
# torch.distributed.all_reduce(overflow_gpu,
# op=torch.distributed.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
if self.zero_reduce_scatter:
......@@ -544,8 +545,9 @@ def memory_status(msg, print_rank=-1, reset_max=False):
)
def see_memory_usage(message):
return
def see_memory_usage(message, force=False):
if not force:
return
if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
return
......@@ -557,6 +559,11 @@ def see_memory_usage(message):
CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024),2)} GB \
Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ")
vm_stats = psutil.virtual_memory()
used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
logger.info(
f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')
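# --- Illustrative usage sketch (editor's addition, not part of the diff) ---
# see_memory_usage() is now gated on `force`, so existing call sites stay silent unless
# explicitly enabled. A hedged sketch (assumes a CUDA device, since the function reads
# torch.cuda memory counters):
#
#   see_memory_usage("before forward", force=False)        # no-op: returns immediately
#   see_memory_usage("after optimizer step", force=True)   # logs GPU MA/CA and CPU virtual memory stats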
def call_to_str(base, *args, **kwargs):
"""Construct a string representation of a call.
......
from .partition_parameters import ZeroParamType
from .partition_parameters import ZeroParamStatus
from .partition_parameters import Init
from .partition_parameters import GatheredParameters
from .partition_parameters import register_external_parameter
......@@ -21,9 +21,27 @@ class DeepSpeedZeroConfig(object):
self.allgather_bucket_size = None
self.overlap_comm = None
self.load_from_fp32_weights = None
self.cpu_offload = None
self.elastic_checkpoint = None
#Offload Specific Parameters
self.cpu_offload = None
self.cpu_offload_params = None
self.cpu_offload_use_pin_memory = None
self.sub_group_size = None
#Stage3 Specific Parameters
self.prefetch_bucket_size = None
self.param_persistence_threshold = None
self.max_live_parameters = None
self.max_reuse_distance = None
if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
if type(zero_config_dict) is bool:
......@@ -66,6 +84,8 @@ class DeepSpeedZeroConfig(object):
self.contiguous_gradients = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT
if self.stage == ZERO_OPTIMIZATION_WEIGHTS else
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
self.reduce_bucket_size = get_scalar_param(
......@@ -77,9 +97,12 @@ class DeepSpeedZeroConfig(object):
ZERO_OPTIMIZATION_REDUCE_SCATTER,
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)
self.overlap_comm = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_OVERLAP_COMM,
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)
self.overlap_comm = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_OVERLAP_COMM,
ZERO3_OPTIMIZATION_OVERLAP_COMM_DEFAULT
if self.stage == ZERO_OPTIMIZATION_WEIGHTS else
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)
self.allgather_partitions = get_scalar_param(
zero_config_dict,
......@@ -104,3 +127,37 @@ class DeepSpeedZeroConfig(object):
zero_config_dict,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)
self.cpu_offload_params = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS,
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT)
self.cpu_offload_use_pin_memory = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY,
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY_DEFAULT)
self.sub_group_size = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_SUB_GROUP_SIZE,
ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT)
self.max_live_parameters = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS,
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT)
self.max_reuse_distance = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE,
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT)
self.prefetch_bucket_size = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE,
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT)
self.param_persistence_threshold = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)
......@@ -13,14 +13,19 @@ ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"stage3_max_live_parameters" : 1000000000,
"stage3_max_reuse_distance" : 1000000000,
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
"cpu_offload": [true|false]
"reduce_bucket_size": 500000000,
"load_from_fp32_weights": [true|false],
"cpu_offload": [true|false],
"cpu_offload_params" : [true|false],
"cpu_offload_use_pin_memory": [true|false],
"sub_group_size" : 1000000000000
}
}
'''
......@@ -30,7 +35,7 @@ ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_WEIGHTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
......@@ -47,9 +52,11 @@ ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO3_OPTIMIZATION_OVERLAP_COMM_DEFAULT = True
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
......@@ -66,18 +73,65 @@ ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT = 'elastic_checkpoint'
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT = True
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS = 'cpu_offload_params'
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT = False
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY = 'cpu_offload_use_pin_memory'
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY_DEFAULT = False
ZERO_OPTIMIZATION_SUB_GROUP_SIZE = 'sub_group_size'
ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT = 1000000000000
#maximum number of parameters per GPU before releasing them
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS = 'stage3_max_live_parameters'
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT = 1000000000
#release a parameter only if the reuse distance is larger than specified
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE = 'stage3_max_reuse_distance'
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT = 1000000000
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE = 'stage3_prefetch_bucket_size'
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT = 50000000
#parameters smaller than the threshold are only communicated once after the
#parameters are updated and are persisted throughout the training
#to avoid tons of latency-bound communication
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD = 'stage3_param_persistence_threshold'
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER:
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT: ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT
ZERO_OPTIMIZATION_CPU_OFFLOAD:
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT:
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS:
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY:
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY_DEFAULT,
ZERO_OPTIMIZATION_SUB_GROUP_SIZE:
ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT,
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS:
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT,
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE:
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT,
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE:
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT
}
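# --- Illustrative configuration sketch (editor's addition, not part of the diff) ---
# A hedged example of a DeepSpeed config dict that exercises the ZeRO stage 3 keys defined
# above. Values mirror the defaults in this file; "train_batch_size" and the "fp16" block
# are assumed for illustration (the engine requires fp16 when ZeRO is enabled). A dict like
# this can be passed as `config_params` to deepspeed.initialize(), or written as JSON.
ZERO3_EXAMPLE_CONFIG = {
    "train_batch_size": 8,                         # assumed value, adjust per job
    "fp16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 3,
        "cpu_offload": True,                       # offload optimizer states to CPU
        "cpu_offload_params": True,                # offload parameters to CPU as well
        "cpu_offload_use_pin_memory": True,
        "sub_group_size": 1000000000000,
        "stage3_max_live_parameters": 1000000000,
        "stage3_max_reuse_distance": 1000000000,
        "stage3_prefetch_bucket_size": 50000000,
        "stage3_param_persistence_threshold": 100000
    }
}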
import torch
def print_rank_0(message):
if torch.distributed.get_rank() == 0:
print(message)
class ContiguousMemoryAllocator(object):
def __init__(self, size, dtype, device):
self.buffer = torch.zeros(size, dtype=dtype, device=device)
#address to contiguous size available
self.contiguous_sizes = {}
self.contiguous_sizes[0] = size
#tensor id to its address
self.tensor_addresses = {}
#tensor address to its size
self.tensor_sizes = {}
#tensor address to ids
self.tensor_ids = {}
#id to tensors
self.tensor_map = {}
#id to params. Maps each tensor buffer to the list of parameters that use it
self.id_to_params = {}
self.total_size = size
self.total_free = size
self.largest_contiguous = size
self.max_allocated = 0
self.count = 0
#create a tensor of the given size from the pre-allocated buffer
#fails if there is not enough free space
#if there is not enough contiguous space, it will defragment and then allocate
def allocate_tensor(self, size):
free_before = self.total_free
assert size <= self.total_free, "Not enough memory in buffer. Allocation failed"
if self.largest_contiguous < size:
print_rank_0("Needs defragmentation to allocate. Before Defragmentation:")
self.print_allocation(resolution=100)
self._defragment_memory()
#set the param data to the new tensor buffer locations
self._reset_param_data()
print_rank_0("After defragmentation:")
self.print_allocation(resolution=100)
self.total_free = self.total_free - size
allocated = self.total_size - self.total_free
if allocated > self.max_allocated:
self.max_allocated = allocated
tensor_address = self._get_new_tensor_address(size)
ret_tensor = self._get_new_tensor(tensor_address, size)
print_rank_0(
f"Free before allocation {free_before}. Allocating {size}. Free after allocation {self.total_free}. Max allocated {self.max_allocated}"
)
assert self.total_free + size == free_before, "Allocation bookkeeping error"
return ret_tensor
#assigns the tensor data to the param data and keeps track of the assignment
#any change to the underlying buffer from defragmentation will cause a
#reassignment of the param data
def assign_to_param(self, tensor, param, numel, shape):
tensor_id = id(tensor)
assert tensor_id in self.tensor_map.keys(), "No such tensor allocated by the allocator."
assert tensor.numel() >= numel, "Tensor buffer is not large enough"
assert not tensor_id in self.id_to_params.keys(), "This tensor has already been assigned to a param"
self.id_to_params[tensor_id] = [param]
replicated_tensor = tensor.narrow(0, 0, numel).view(shape)
param.data = replicated_tensor.data
param.contiguous_tensor_id = tensor_id
#deletes the tensor and frees up the underlying buffer
def release_tensor(self, tensor):
free_before = self.total_free
tensor_id = id(tensor)
tensor_size = tensor.numel()
self._release_tensor(tensor_id)
self._unassign_params(tensor_id)
self.total_free += tensor_size
print_rank_0(
f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
)
assert self.total_free - tensor_size == free_before, "Release bookkeeping error"
def release_tensor_with_id(self, tensor_id):
free_before = self.total_free
assert tensor_id in self.tensor_map.keys(), "Invalid tensor id"
tensor = self.tensor_map[tensor_id]
tensor_size = tensor.numel()
self._release_tensor(tensor_id)
self._unassign_params(tensor_id)
self.total_free += tensor_size
print_rank_0(
f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
)
assert self.total_free - tensor_size == free_before, "Release bookkeeping error"
#shows the current memory allocation at specified resolution
def print_allocation(self, resolution=200):
total_size = self.buffer.numel() * 1.0
empty = []
for addr, size in self.contiguous_sizes.items():
start = int(addr * resolution / total_size)
end = int((addr + size) * resolution / total_size)
empty.extend(range(start, end))
s = ''
for i in range(resolution):
s += '.' if i in empty else '|'
print_rank_0(s)
def max_allocated(self):
return self.max_allocated
#to be called after defragmentation that moves the tensor buffers
#this call reassigns the data of all the parameters using the tensor buffers
def _reset_param_data(self):
for id, tensor in self.tensor_map.items():
for param in self.id_to_params[id]:
param.data = tensor.narrow(0,
0,
param.numel()).view(param.data.shape).data
def _unassign_params(self, tensor_id):
if tensor_id in self.id_to_params.keys():
del self.id_to_params[tensor_id]
def _release_tensor(self, tensor_id):
assert tensor_id in self.tensor_addresses, f"Tensor id {tensor_id} not found"
address = self.tensor_addresses[tensor_id]
contiguous_size = self.tensor_map[tensor_id].numel()
del self.tensor_addresses[tensor_id]
del self.tensor_ids[address]
del self.tensor_map[tensor_id]
del self.tensor_sizes[address]
self._consolidate_address(address, contiguous_size)
self.largest_contiguous = self._largest_contiguous()
def _consolidate_address(self, address, contiguous_size):
#consolidate next buffer
end_address = address + contiguous_size
if end_address in self.contiguous_sizes:
contiguous_size += self.contiguous_sizes[end_address]
del self.contiguous_sizes[end_address]
#consolidate previous buffer
for addr, size in self.contiguous_sizes.items():
if addr + size == address:
del self.contiguous_sizes[addr]
contiguous_size += size
address = addr
break
self.contiguous_sizes[address] = contiguous_size
def _defragment_memory(self):
empty_addresses = sorted(self.contiguous_sizes.keys())
tensor_addresses = sorted(self.tensor_addresses.values())
tensor_index = 0
while tensor_index < len(tensor_addresses):
empty_addr = empty_addresses[0]
empty_size = self.contiguous_sizes[empty_addr]
tensor_addr = tensor_addresses[tensor_index]
tensor_size = self.tensor_sizes[tensor_addr]
tensor_id = self.tensor_ids[tensor_addr]
tensor = self.tensor_map[self.tensor_ids[tensor_addr]]
assert tensor_size == tensor.numel(), \
"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} "
assert empty_addr != tensor_addr, \
f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}"
if empty_addr < tensor_addr:
if empty_size >= tensor_size:
dest_buffer = self.buffer.narrow(0, empty_addr, tensor_size)
src_buffer = self.buffer.narrow(0, tensor_addr, tensor_size)
dest_buffer.data.copy_(src_buffer.data)
else:
#print_rank_0(f'empty addr : {empty_addr}, empty size {empty_size} tensor addr {tensor_addr} tensor size {tensor_size}')
src_addr = tensor_addr
dest_addr = empty_addr
while src_addr < (tensor_addr + tensor_size):
copy_size = min(empty_size, tensor_addr + tensor_size - src_addr)
dest_buffer = self.buffer.narrow(0, dest_addr, copy_size)
src_buffer = self.buffer.narrow(0, src_addr, copy_size)
dest_buffer.data.copy_(src_buffer.data)
src_addr += copy_size
dest_addr += copy_size
self._replace_old_address_with_new(tensor_id, empty_addr)
tensor_index += 1
else:
tensor_index += 1
empty_addresses = sorted(self.contiguous_sizes.keys())
def _replace_old_address_with_new(self, tensor_id, new_address):
tensor = self.tensor_map[tensor_id]
tensor_size = tensor.numel()
tensor.data = self.buffer.narrow(0, new_address, tensor_size).data
self._release_tensor(tensor_id)
self._mark_as_occupied(new_address, tensor_size)
self.tensor_ids[new_address] = tensor_id
self.tensor_map[tensor_id] = tensor
self.tensor_addresses[tensor_id] = new_address
self.tensor_sizes[new_address] = tensor_size
def _get_new_tensor_address(self, size):
tensor_address = None
for address, contiguous_size in self.contiguous_sizes.items():
if contiguous_size >= size and \
(tensor_address is None or \
contiguous_size < self.contiguous_sizes[tensor_address]):
tensor_address = address
assert tensor_address is not None, "address cannot be None"
return tensor_address
def _get_new_tensor(self, address, size):
available_contiguous_size = self.contiguous_sizes[address]
assert size <= available_contiguous_size, \
f"Tensor numel {size} is large than available contiguous size {available_contiguous_size}"
self.count += 1
new_tensor = self.buffer.narrow(0, address, size)
tensor_id = id(new_tensor)
self.tensor_addresses[tensor_id] = address
self.tensor_sizes[address] = size
self.tensor_ids[address] = tensor_id
self.tensor_map[tensor_id] = new_tensor
self._mark_as_occupied(address, size)
return new_tensor
def _largest_contiguous(self):
if len(self.contiguous_sizes) > 0:
return max([size for _, size in self.contiguous_sizes.items()])
else:
return 0
def _mark_as_occupied(self, address, size):
available_contiguous_size = self.contiguous_sizes[address]
del self.contiguous_sizes[address]
if available_contiguous_size != size:
self.contiguous_sizes[address + size] = available_contiguous_size - size
self.largest_contiguous = self._largest_contiguous()
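# --- Illustrative usage sketch (editor's addition, not part of the diff) ---
# A minimal exercise of the allocator above. print_rank_0() calls torch.distributed.get_rank(),
# so a single-process gloo group is initialized first; sizes, shapes, and the rendezvous
# address are arbitrary illustration values.
import torch
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:29500',
                            rank=0,
                            world_size=1)

allocator = ContiguousMemoryAllocator(size=1024, dtype=torch.float32, device='cpu')
param = torch.nn.Parameter(torch.zeros(16, 16))

buf = allocator.allocate_tensor(256)                     # carve 256 elements from the buffer
allocator.assign_to_param(buf, param, numel=256, shape=(16, 16))
allocator.print_allocation(resolution=32)                # '|' = occupied, '.' = free
allocator.release_tensor(buf)                            # return the space for reuse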
#Linear Module to use with ZeRO Stage 3 to allow for parameter memory release
#after the module execution during forward
#Instead of saving variables using save_for_backward, we save variable ids
#Allowing us to retrieve the variable without creating a pointer to it
#Which allows the underlying tensor to be garbage collected
#When partitioned as needed by the ZeRO Stage 3 optimizer
#TODO instead of patching the Linear module, we could patch ctx.save_for_backward and
#ctx.saved_tensors so that this approach works for all nn modules that are built upon
#torch.nn.functional. However, the issue is that many modules use C++ implementations
#which do not have a PyTorch implementation. E.g. torch.addmm, which acts as a functional
#when implemented outside of torch.autograd.Function
import math
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn.modules.module import Module
tensor_map = {}
class LinearFunctionForZeroStage3(torch.autograd.Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
#print("In ZeRO Linear Function")
weight_id = id(weight)
bias_id = id(bias)
#ctx.save_for_backward(input, weight, bias)
ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id))
tensor_map[weight_id] = weight
tensor_map[bias_id] = bias
if input.dim() == 2 and bias is not None:
# fused op is marginally faster
ret = torch.addmm(bias, input, weight.t())
else:
output = input.matmul(weight.t())
if bias is not None:
output += bias
ret = output
return ret
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
#input, weight, bias = ctx.saved_tensors
input, weight_id, bias_id = ctx.saved_tensors
weight = tensor_map[weight_id.item()]
bias = tensor_map[bias_id.item()]
grad_input = grad_weight = grad_bias = None
#print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
#print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
grad_input = grad_output.matmul(weight)
#print(f"Computed grad input {grad_input.shape}")
if ctx.needs_input_grad[1]:
#print("Computing grad weight")
dim = grad_output.dim()
if dim > 2:
grad_weight = grad_output.view(-1,
grad_output.shape[-1]).t().matmul(
input.view(-1,
input.shape[-1]))
else:
grad_weight = grad_output.t().matmul(input)
#print(f"Computed grad weight grad_weight {grad_weight.shape}")
if bias is not None and ctx.needs_input_grad[2]:
#print("Computing grad bias")
grad_bias = grad_output.sum(0)
#print("Done computing grad bias")
#print("needs bias")
#print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
return grad_input, grad_weight, grad_bias
class LinearModuleForZeroStage3(Module):
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
The weights are pre-transposed and stored as A^T instead of transposing during each
forward. Memory savings proportional to the parameter size.
Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
Shape:
- Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
additional dimensions and :math:`H_{in} = \text{in\_features}`
- Output: :math:`(N, *, H_{out})` where all but the last dimension
are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
Attributes:
weight: the learnable weights of the module of shape
:math:`(\text{out\_features}, \text{in\_features})`. The values are
initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = \frac{1}{\text{in\_features}}`
bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{1}{\text{in\_features}}`
Examples::
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
super(LinearModuleForZeroStage3, self).__init__()
print("Building ZeRO module")
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor) -> Tensor:
return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features,
self.out_features,
self.bias is not None)
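# --- Illustrative usage sketch (editor's addition, not part of the diff) ---
# LinearModuleForZeroStage3 is a drop-in replacement for torch.nn.Linear whose autograd
# function stores weight/bias ids instead of the tensors themselves. A minimal CPU-only
# check of the forward/backward path; shapes are arbitrary illustration values.
import torch

layer = LinearModuleForZeroStage3(in_features=20, out_features=30)
x = torch.randn(128, 20, requires_grad=True)
y = layer(x)                      # routed through LinearFunctionForZeroStage3.apply
y.sum().backward()                # gradients flow back via the id -> tensor map
print(y.shape, layer.weight.grad.shape)   # torch.Size([128, 30]) torch.Size([30, 20])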
import os
import time
import types
from enum import Enum
import functools
import itertools
import torch
from torch.distributed.distributed_c10d import _get_global_rank
from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.utils import log_dist, init_distributed
param_count = 0
def print_rank_0(message, debug=False, force=False):
if torch.distributed.get_rank() == 0 and (debug or force):
print(message)
def is_zero_param(parameter):
return hasattr(parameter, 'ds_id')
def _init_external_params(module):
if not hasattr(module, '_external_params'):
module._external_params = {}
def external_parameters(self):
if not hasattr(self, '_external_params'):
self._external_params = {}
return self._external_params.items()
def all_parameters(self):
return itertools.chain(self.named_parameters(self,
recurse=False),
external_parameters(self))
module.ds_external_parameters = types.MethodType(external_parameters, module)
module.all_parameters = types.MethodType(all_parameters, module)
def register_external_parameter(module, parameter):
"""Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in
the forward and backward passes of ``module``.
This is used when a parameter is accessed outside of its owning module's
``forward()``. DeepSpeed must know to collect it from its partitioned
state and when to release the memory.
.. note::
This is only applicable to training with ZeRO stage 3.
Args:
module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass.
parameter (``torch.nn.Parameter``): The parameter to register.
Raises:
RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
Examples
========
#. Register a weight that is used in another module's forward pass (line 6).
Parameter ``layer1.weight`` is used by ``layer2`` (line 11).
.. code-block:: python
:linenos:
:emphasize-lines: 6,11
class ModuleZ3(torch.nn.Module):
def __init__(self, *args):
super().__init__(self, *args)
self.layer1 = SomeLayer()
self.layer2 = OtherLayer()
deepspeed.zero.register_external_parameter(self, self.layer1.weight)
def forward(self, input):
x = self.layer1(input)
# self.layer1.weight is required by self.layer2.forward
y = self.layer2(x, self.layer1.weight)
return y
"""
if not isinstance(parameter, torch.nn.Parameter):
raise RuntimeError('Parameter is not a torch.nn.Parameter')
if not hasattr(module, '_external_params'):
_init_external_params(module)
key = id(parameter)
module._external_params[key] = parameter
class ZeroParamType(Enum):
# same as regular pytorch parameters
NORMAL = 1
# parameters are partitioned across data parallel process
PARTITIONED = 2
# the parameter is held with a unique process rank
# and is not available on all other processes
REMOTE = 3
class ZeroParamStatus(Enum):
# parameters are fully present and ready for use on all processes
AVAILABLE = 1
# parameters are either partitioned or remote in some or all processes
NOT_AVAILABLE = 2
# parameters are being gathered.
INFLIGHT = 3
_orig_torch_empty = torch.empty
def empty_cuda_tensor(*size, **kwargs):
if not 'device' in kwargs.keys():
kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
tensor = _orig_torch_empty(*size, **kwargs)
if tensor.is_floating_point():
return tensor.half()
else:
return tensor
def new_cuda_tensor(cls, *args):
    device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
    tensor = torch.ones((1, 1), device=device).new_empty(*args)
    if tensor.is_floating_point():
        return tensor.half()
    else:
        return tensor
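# Editor's illustration (not part of the original code): while the Init context
# below is active, torch.empty and torch.Tensor.__new__ are patched by the two
# helpers above, so ordinary allocations land on the local GPU in half precision.
# Assuming LOCAL_RANK=0, roughly:
#
#     w = torch.empty(1024, 1024)   # patched: device cuda:0, dtype torch.float16
#     # after the context exits, torch.empty behaves as usual again:
#     w = torch.empty(1024, 1024)   # default device, dtype torch.float32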
reuse_buffers = False
temp_contiguous_tensor = None
empty_buffers = {}
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
def __init__(self, enabled=True, mem_efficient_linear=True):
self.mem_efficient_linear = mem_efficient_linear
self.enabled = enabled
def __enter__(self):
if not self.enabled:
return
def partition_after(f):
@functools.wraps(f)
def wrapper(module, *args, **kwargs):
print_rank_0(f'Before initializing {module.__class__.__name__}',
force=False)
f(module, *args, **kwargs)
self._post_init_method(module)
print_rank_0(
f'After initializing followed by post init for {module.__class__.__name__}',
force=False)
return wrapper
def _enable_class(cls):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
def _init_subclass(cls, **kwargs):
cls.__init__ = partition_after(cls.__init__)
# Replace .__init__() for all existing subclasses of torch.nn.Module
for subclass in torch.nn.modules.module.Module.__subclasses__():
_enable_class(subclass)
# holding on to the current __init__subclass__ for exit
torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
torch.Tensor.__old_new__ = torch.Tensor.__new__
# Replace .__init__() for future subclasses of torch.nn.Module
torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
torch.Tensor.__new__ = new_cuda_tensor
torch.empty = empty_cuda_tensor
if self.mem_efficient_linear:
self.linear_bk = torch.nn.functional.linear
torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
def __exit__(self, exc_type, exc_value, traceback):
if not self.enabled:
return
def _disable_class(cls):
cls.__init__ = cls._old_init
# Replace .__init__() for all existing subclasses of torch.nn.Module
for subclass in torch.nn.modules.module.Module.__subclasses__():
_disable_class(subclass)
# Replace .__init__() for future subclasses of torch.nn.Module
torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
torch.Tensor.__new__ = torch.Tensor.__old_new__
torch.empty = _orig_torch_empty
if self.mem_efficient_linear:
torch.nn.functional.linear = self.linear_bk
        # Now that the patched methods have been restored, let any exception propagate.
if exc_type is not None:
return False
# To be implemented by inheriting classes
def _post_init_method(self, module):
pass
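# Editor's sketch (illustration only): the context manager above rebinds the
# __init__ of every torch.nn.Module subclass so that _post_init_method runs right
# after a module is constructed. Conceptually, for a user-defined layer:
#
#     class MyLayer(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.weight = torch.nn.Parameter(torch.empty(10, 10))
#
#     with Init():           # __init__ is wrapped by partition_after
#         layer = MyLayer()  # _post_init_method(layer) runs here and partitions
#                            # layer.weight across the data-parallel group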
# Replaces all parameters in module with Scattered Parameters
class Init(InsertPostInitMethodToModuleSubClasses):
param_id = 0
def __init__(self,
module=None,
data_parallel_group=None,
mem_efficient_linear=True,
remote_device=None,
pin_memory=False,
enabled=True):
"""A context to enable massive model construction for training with
ZeRO-3. Models are automatically partitioned (or, sharded) across the
system and converted to half precision.
Args:
module (``torch.nn.Module``, optional): If provided, partition the model as
if it was constructed in the context.
data_parallel_group (``torch.distributed`` process group, optional):
The group of processes to partition among. Defaults to all processes.
mem_efficient_linear (bool, optional): Replace
torch.nn.functional.linear with an implementation that allows
DeepSpeed to partition parameters. Defaults to ``True``.
remote_device (string, optional): The device to store model
weights. Passing ``"cpu"`` will create the model in CPU
memory. The model may still be moved to GPU if
``cpu_offload_param`` is ``False`` in the config provided to
:meth:`deepspeed.initialize`. Defaults to the local GPU.
pin_memory (bool, optional): Potentially increase performance by
using pinned memory for model weights. ``remote_device`` must be
``"cpu"``. Defaults to ``False``.
enabled (bool, optional): If ``False``, this context has no
effect. Defaults to ``True``.
This context accelerates model initialization and enables models that
are too large to allocate in their entirety in CPU memory. It has the
following effects:
#. allocates tensors to either GPU or CPU memory
#. converts floating point tensors to half precision
#. immediately partitions tensors among the group of data-parallel devices
#. (*optional*) replaces ``torch.nn.functional.linear`` with a more
memory-efficient implementation
These modifications allow for models that exceed the size of local CPU/GPU
memory, but fit within the total system memory (*i.e.*, aggregate CPU
or GPU memory) across all nodes. Consider initializing a model with one
trillion parameters, whose weights occupy two terabytes (TB) in half
precision. The initial CPU allocation in full precision requires 4TB of
memory *per process*, and so a system with 8 GPUs per node would need 32TB of
CPU memory due to data-parallel redundancies. Instead, by immediately
partitioning tensors we remove the redundancies. The result is that
regardless of the number of GPUs, we still only require the original 4TB. This
allows for a linear increase in model size with the aggregate system memory.
For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
parameter model with 4 nodes and 32 GPUs.
.. note::
Initializes ``torch.distributed`` if it has not already been done so.
            See :meth:`deepspeed.init_distributed` for more information.
.. note::
Can also be used as a decorator:
.. code-block:: python
@deepspeed.zero.Init()
def get_model():
return MyLargeModel()
.. note::
Only applicable to training with ZeRO-3.
Examples
--------
#. Allocate a model and partition it among all processes:
.. code-block:: python
with deepspeed.zero.Init():
model = MyLargeModel()
#. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:
.. code-block:: python
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device="cpu",
pin_memory=True):
model = MyLargeModel()
#. Partition an already-allocated model in CPU memory:
.. code-block:: python
model = deepspeed.zero.Init(module=model)
"""
super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear)
if not torch.distributed.is_initialized():
init_distributed()
assert torch.distributed.is_initialized(), "Parameters cannot be scattered without initializing torch.distributed"
if data_parallel_group is None:
self.ds_process_group = torch.distributed.group.WORLD
else:
self.ds_process_group = data_parallel_group
self.rank = torch.distributed.get_rank(group=self.ds_process_group)
self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)
#Local device is the device where the parameters are consumed
#It is the device where parameters are fully instantiated using allgather
self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
        #Remote device is the device where parameter partitions are stored
        #It can be the same as local_device or it can be the CPU.
self.remote_device = self.local_device if remote_device is None else remote_device
self.pin_memory = pin_memory if (self.remote_device == 'cpu') else False
# If we are provided an already-allocated module to prepare.
if module is not None:
assert isinstance(module, torch.nn.Module)
for param in module.parameters(recurse=True):
if is_zero_param(param):
continue
self._convert_to_deepspeed_param(param)
param.partition()
def _post_init_method(self, module):
#see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
see_memory_usage(
f"Before converting and partitioning parmas in {module.__class__.__name__}",
force=False)
global param_count
for name, param in module.named_parameters(recurse=False):
param_count += param.numel()
if not is_zero_param(param):
self._convert_to_deepspeed_param(param)
print_rank_0(
f"Partitioning param with ds id {param.ds_id} and shape {param.data.shape}"
)
param.partition()
see_memory_usage(
f"Param count {param_count}. After converting and partitioning parmas in {module.__class__.__name__}",
force=False)
def _convert_to_deepspeed_param(self, param):
# Partitioned, Normal, Remote
param.ds_param_type = ZeroParamType.PARTITIONED
# Replicated vs Partitioned vs Inflight
param.ds_status = ZeroParamStatus.AVAILABLE
# Stores the shape of the original tensor
param.ds_shape = param.shape
        # Stores the number of elements in the original parameter without padding
param.ds_numel = param.numel()
        # Stores the partitioned copy of the tensor
param.ds_tensor = None
# Keeps track of how many active sub-modules need this param at any given point in time
param.ds_active_sub_modules = 0
        # If this flag is true, then the parameter is replicated throughout training
        # and only partitioned before the optimizer step
param.ds_persist = False
# The group that the parameter is scattered across.
param.ds_process_group = self.ds_process_group
        # DeepSpeed Param ID
param.ds_id = Init.param_id
Init.param_id += 1
def all_gather(param_list=None, async_op=False, hierarchy=0):
cls = param
if param_list is None:
param_list = [cls]
return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
def partition(param_list=None, hierarchy=0, has_been_updated=False):
cls = param
print_rank_0(
f"{'--'*hierarchy}----Partitioning param with id {cls.ds_id} dev {cls.device} shape {cls.shape}"
)
if param_list is None:
param_list = [cls]
self._partition(param_list, has_been_updated=has_been_updated)
def reduce_gradients_at_owner(param_list=None, hierarchy=0):
cls = param
if param_list is None:
param_list = [cls]
print_rank_0(
f"{'--'*hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner"
)
self._reduce_scatter_gradients(param_list)
def partition_gradients(param_list=None,
partition_buffers=None,
hierarchy=0,
accumulate=False):
cls = param
print_rank_0(
f"{'--'*hierarchy}----Partitioning param gradient with id {cls.ds_id}")
if param_list is None:
param_list = [cls]
if isinstance(partition_buffers, torch.Tensor):
partition_buffers = [partition_buffers]
self._partition_gradients(param_list,
partition_buffers=partition_buffers,
accumulate=accumulate)
def aligned_size():
return self._aligned_size(param)
def padding_size():
return self._padding_size(param)
# Collectives for gathering and partitioning parameters
param.all_gather = all_gather
param.partition = partition
# Collective for averaging gradients
param.reduce_gradients_at_owner = reduce_gradients_at_owner
param.partition_gradients = partition_gradients
# Partitioning size utilities
param.aligned_size = aligned_size
param.padding_size = padding_size
def _aligned_size(self, param):
return param.ds_numel + self._padding_size(param)
def _padding_size(self, param):
remainder = param.ds_numel % self.world_size
return (self.world_size - remainder) if remainder else 0
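    # Editor's worked example (illustration only): with world_size = 4 and a
    # parameter of ds_numel = 10 elements,
    #     remainder      = 10 % 4        # 2
    #     padding        = 4 - 2         # 2   (_padding_size)
    #     aligned        = 10 + 2        # 12  (_aligned_size)
    #     partition_size = 12 // 4       # 3 elements stored per rank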
def _all_gather(self, param_list, async_op=False, hierarchy=None):
handles = []
all_gather_list = []
for param in param_list:
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if async_op:
handle = self._allgather_param(param,
async_op=async_op,
hierarchy=hierarchy)
param.ds_status = ZeroParamStatus.INFLIGHT # if async_op else ZeroParamStatus.AVAILABLE
handles.append(handle)
else:
all_gather_list.append(param)
if not async_op:
ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
for param in all_gather_list:
param.ds_status = ZeroParamStatus.AVAILABLE
return ret_value
return handles
def _partition(self, param_list, force=False, has_been_updated=False):
for param in param_list:
#print_rank_0(f"Before Partitioning Param {param.ds_id}")
#self._param_status(param)
self._partition_param(param, has_been_updated=has_been_updated)
param.ds_status = ZeroParamStatus.NOT_AVAILABLE
#if param.ds_tensor is not None:
# assert id(param.data) == id(param.ds_tensor.data), \
# "After the parameters are initially partitioned, make sure we are not recreating the partition."
#print_rank_0(f"After Partitioning Param {param.ds_id}")
# self._param_status(param)
def _partition_param(self, param, has_been_updated=False):
        assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
global reuse_buffers
#print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}")
if param.ds_status is ZeroParamStatus.AVAILABLE:
print_rank_0(
f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}",
force=False)
# if reuse_buffers and False:
# numel = buffer.numel()
# buffer = param.data.view(-1)
# print_rank_0(
# "Returning buffer for param {param.ds_id} with numel {param.ds_numel} to empty buffers",
# force=False)
# if numel in empty_buffers:
# empty_buffers[numel].append(buffer)
#if torch.distributed.get_rank():
# print(f"Releasing {param.data.numel()}")
if param.ds_tensor is not None and not has_been_updated:
#param.data = param.ds_tensor.data
#param.data does not store anything meaningful in partitioned state
param.data = torch.ones(1).half().to(param.device)
return
tensor_size = self._aligned_size(param)
partition_size = tensor_size // self.world_size
if param.ds_tensor is None:
partitioned_tensor = torch.zeros(partition_size,
dtype=param.dtype,
device=self.remote_device)
partitioned_tensor.requires_grad = False
if self.pin_memory:
partitioned_tensor = partitioned_tensor.pin_memory()
param.ds_tensor = partitioned_tensor
start = partition_size * self.rank
end = start + partition_size
one_dim_param = param.contiguous().view(-1)
if start < param.ds_numel and end <= param.ds_numel:
src_tensor = one_dim_param.narrow(0, start, partition_size)
param.ds_tensor.copy_(src_tensor)
#partitioned_tensor = src_tensor.clone().detach().to(self.remote_device)
else:
# partitioned_tensor = torch.zeros(partition_size,
# dtype=param.dtype,
# device=self.remote_device )
if start < param.ds_numel:
elements_to_copy = param.ds_numel - start
param.ds_tensor.narrow(0,
0,
elements_to_copy).copy_(
one_dim_param.narrow(
0,
start,
elements_to_copy))
#print(f"Remote device {self.remote_device}")
#param.ds_tensor = partitioned_tensor
#param.data = param.ds_tensor.data
#param.data does not store anything meaningful in partitioned state
param.data = torch.ones(1).half().to(param.device)
print_rank_0(
f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}"
)
def _param_status(self, param):
if param.ds_tensor is not None:
print_rank_0(
f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned numel {param.ds_tensor.numel()}, data numel {param.data.numel()}"
)
else:
print_rank_0(
f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned ds_tensor {param.ds_tensor}, data numel {param.data.numel()}"
)
def _allgather_param(self, param, async_op=False, hierarchy=0):
partition_size = param.ds_tensor.numel()
tensor_size = partition_size * self.world_size
aligned_param_size = self._aligned_size(param)
assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'
print_rank_0(
f"{'--'* hierarchy}---- Before allocating Allgather param with id {param.ds_id} and status {param.ds_status} Partition Size {partition_size} and data shape {param.ds_shape}"
)
flat_tensor = torch.zeros(aligned_param_size,
dtype=param.dtype,
device=param.device).view(-1)
torch.cuda.synchronize()
print_rank_0(
f"{'--'* hierarchy}----Allgather param with id {param.ds_id} and status {param.ds_status} Partition Size {partition_size} and data shape {param.ds_shape}"
)
# if not flat_tensor.numel() > 100000:
# replicated_tensor = flat_tensor.narrow(0,
# 0,
# param.ds_numel).view(param.ds_shape)
# param.data = replicated_tensor.data
# return None
partitions = []
for i in range(self.world_size):
partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
if i == torch.distributed.get_rank(group=self.ds_process_group):
partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
handle = torch.distributed.all_gather(partitions,
partitions[self.rank],
group=self.ds_process_group,
async_op=async_op)
replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
param.data = replicated_tensor.data
return handle
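    # Editor's sketch of the buffer built above (illustration only), for
    # world_size = 4 and partition_size = 3:
    #
    #     flat_tensor = [ rank0 | rank1 | rank2 | rank3 ]   # 4 * 3 = 12 elements
    #
    # each rank copies its own ds_tensor into its slot, all_gather fills the
    # remaining slots, and the first ds_numel elements are viewed as ds_shape to
    # reconstruct the full parameter that param.data then points at.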
def _allgather_params(self, param_list, hierarchy=0):
if len(param_list) == 0:
return
partition_size = sum([param.ds_tensor.numel() for param in param_list])
tensor_size = partition_size * self.world_size
flat_tensor = torch.empty(tensor_size,
dtype=param_list[0].dtype,
device=self.local_device)
        flat_tensor.requires_grad = False
partitions = []
for i in range(self.world_size):
start = partition_size * i
partitions.append(flat_tensor.narrow(0, start, partition_size))
if i == self.rank:
offset = 0
for param in param_list:
param_numel = param.ds_tensor.numel()
partitions[i].narrow(0,
offset,
param_numel).copy_(param.ds_tensor.data)
offset += param_numel
torch.distributed.all_gather(partitions,
partitions[self.rank],
group=self.ds_process_group,
async_op=False)
param_offset = 0
for param in param_list:
param_partition_size = param.ds_tensor.numel()
param_size = param.ds_numel
replicated_tensor = torch.empty(param.ds_shape,
dtype=param.dtype,
device=self.local_device)
for i in range(self.world_size):
start = i * partition_size
param_start = i * param_partition_size
if param_start < param_size:
numel_to_copy = min(param_size - param_start, param_partition_size)
part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
replicated_tensor.view(-1).narrow(0,
param_start,
numel_to_copy).copy_(part_to_copy)
#param_offset += param.data.numel()
param_offset += param.ds_tensor.numel()
param.data = replicated_tensor.data
return None
def _reduce_scatter_gradients(self, param_list):
#print_rank_0([param.grad for param in param_list])
#assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered"
handles_and_reduced_partitions = []
for param in param_list:
            assert param.grad.numel() == param.ds_numel, \
                f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not the same as the params"
handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param))
for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions):
if handle is not None:
handle.wait()
# some ranks may have partitions that are padded to go beyond the grad size.
# For these ranks the output of reduce scatter is a separate buffer and needs
# to be copied in
partition_size = param.ds_tensor.numel()
start = self.rank * partition_size
end = start + partition_size
#print_rank_0("REduce scatter was executed for praam {param.ds_id}")
if start < param.ds_numel and end > param.ds_numel:
elements = param.ds_numel - start
param.grad.view(-1).narrow(0,
start,
elements).copy_(
reduced_partition.narrow(0,
0,
elements))
def _reduce_scatter_gradient(self, param):
partition_size = param.ds_tensor.numel()
#output = torch.empty(partition_size, dtype=param.dtype, device=param.device)
total_size = partition_size * self.world_size
input_list = []
for i in range(self.world_size):
start = i * partition_size
end = start + partition_size
#print("before reduce scatter gradients")
if start < param.ds_numel and end <= param.ds_numel:
input = param.grad.view(-1).narrow(0, start, partition_size)
else:
input = torch.zeros(partition_size,
dtype=param.dtype,
device=param.device)
if start < param.ds_numel:
elements = param.ds_numel - start
input.narrow(0,
0,
elements).copy_(
param.grad.view(-1).narrow(0,
start,
elements))
#print("after reduce scatter gradients")
input_list.append(input)
rank = torch.distributed.get_rank(group=self.ds_process_group)
handle = torch.distributed.reduce_scatter(input_list[rank],
input_list,
group=self.ds_process_group,
async_op=True)
return handle, input_list[rank]
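    # Editor's note (illustration only): torch.distributed.reduce_scatter(output,
    # input_list, ...) sums input_list[i] across all ranks and writes chunk i's
    # result into rank i's output buffer. Above, input_list[rank] doubles as the
    # output, so after handle.wait() each rank holds the summed gradient for its
    # own partition; padded tail chunks are zero-filled and do not perturb the sum.
    # E.g. with world_size = 2 and grad = [g0 | g1] on every rank, rank 0 ends up
    # with sum_over_ranks(g0) and rank 1 with sum_over_ranks(g1).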
def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False):
if partition_buffers is None:
partition_buffers = [None] * len(param_list)
for param, partition_buffer in zip(param_list, partition_buffers):
self._partition_gradient(param,
partition_buffer=partition_buffer,
accumulate=accumulate)
def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
#import pdb;pdb.set_trace()
# param.grad=None
# param.grad.test()
print_rank_0(
f"Partitioning param {id(param)} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.numel()}"
)
see_memory_usage("Before partitioning gradients", force=False)
partition_size = param.ds_tensor.numel()
if partition_buffer is None:
assert not accumulate, "No buffer to accumulate to"
partition_buffer = torch.zeros(partition_size,
dtype=param.dtype,
device=param.device)
else:
            assert partition_buffer.numel() >= partition_size, f"The partition buffer size {partition_buffer.numel()} should be at least the size of param.ds_tensor {partition_size}"
rank = torch.distributed.get_rank(group=self.ds_process_group)
start = partition_size * rank
end = start + partition_size
dest_tensor = partition_buffer.view(-1).narrow(0, 0, partition_size)
#print("before partition gradients")
if start < param.ds_numel:
elements = min(param.ds_numel - start, partition_size)
dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements)
src_tensor = param.grad.view(-1).narrow(0, start, elements)
# just copy the grad partition to the buffer
if not accumulate:
dest_tensor.copy_(src_tensor)
                # if source and destination are on the same device,
                # add to the provided buffer
elif src_tensor.device == dest_tensor.device:
dest_tensor.add_(src_tensor)
                # if source and destination are on different devices, copy the
                # destination into a temporary buffer on the source device, add
                # there, and copy back. This runs faster when src is GPU and dest
                # is CPU, since accumulating directly into CPU memory is very slow
else:
acc_tensor = torch.empty(src_tensor.numel(),
dtype=param.dtype,
device=param.device)
acc_tensor.copy_(dest_tensor)
acc_tensor.add_(src_tensor)
dest_tensor.copy_(acc_tensor)
# partition_buffer.view(-1).narrow(
# 0,
# 0,
# elements).copy_(param.grad.view(-1).narrow(0,
# start,
# elements))
#print("after partition gradients")
param.grad.data = dest_tensor.data
see_memory_usage("After partitioning gradients", force=False)
class GatheredParameters:
def __init__(self, param, modifier_rank=None, fwd_module=None, enabled=True):
"""A context that collects a parameter that was partitioned via a
:class:`deepspeed.zero.Init` context. The parameter is partitioned
again upon exit.
Args:
param (``torch.nn.Parameter``): The parameter to collect.
            modifier_rank (int, optional): If specified, this rank's parameter will be
                broadcasted after the context. This argument is required if ``param`` is
                modified, so that all processes have a consistent view of the data. Defaults
                to ``None``.
fwd_module (``torch.nn.Module``, optional): If specified, ``param`` will be
registered as an external parameter of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.
Examples
========
#. Allocate a partitioned module, initialize its weight on rank 0, and update all
processes.
.. code-block:: python
with deepspeed.zero.Init():
linear = torch.nn.Linear(1000,1000)
with deepspeed.zero.GatheredParameters(linear.weight,
modifier_rank=0):
if torch.distributed.get_rank() == 0:
linear.weight.zero_()
#. Collect a partitioned weight to pass to another module during
training. The parameter will be registered as an external parameter
and made available during the backward pass.
.. code-block:: python
:emphasize-lines: 6
def forward(self, input):
x = self.layer1(input)
# self.layer1.weight is required by self.layer2.forward
with deepspeed.zero.GatheredParameters(self.layer1.weight,
fwd_module=self):
y = self.layer2(x, self.layer1.weight)
return y
"""
self.enabled = enabled
if not enabled:
return
        # Parameters that were not partitioned by ZeRO-3 make this context a no-op.
        if not is_zero_param(param):
self.enabled = False
return
self.param = param
self.src_rank = None
if modifier_rank is not None:
if self.param.ds_process_group == torch.distributed.group.WORLD:
self.src_rank = modifier_rank
else:
# A group was specified; convert DP rank to global rank
self.src_rank = _get_global_rank(self.param.ds_process_group,
modifier_rank)
self.fwd_module = fwd_module
if self.fwd_module is not None:
# is a no-op if already registered
register_external_parameter(self.fwd_module, self.param)
def __enter__(self):
if not self.enabled:
return
self.param.all_gather()
def __exit__(self, *exc):
if not self.enabled:
return
if self.src_rank is not None:
torch.distributed.broadcast(self.param,
self.src_rank,
group=self.param.ds_process_group)
self.param.partition(has_been_updated=self.src_rank is not None)
from deepspeed.utils.logging import logger
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import os
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
from torch._six import inf
from torch.autograd import Variable
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, ZeroParamType, _init_external_params, Init, is_zero_param
from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.ops.adam import DeepSpeedCPUAdam
import itertools
# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False
try:
from apex_C import flatten
from apex_C import unflatten
except ImportError:
try:
_ = warned_flatten
except NameError:
logger.warning(
"apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
)
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
def print_rank_0(message, debug=False, force=False):
if torch.distributed.get_rank() == 0 and (debug or force):
logger.info(message)
def input(msg):
return
def split_half_float_double(tensors):
dtypes = [
"torch.cuda.HalfTensor",
"torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor"
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def isclose(a, b, rtol=1e-09, atol=0.0):
return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
    from math import gcd  # fractions.gcd was removed in Python 3.9
return x * y // gcd(x, y)
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment):
num_elements = 0
for tens in tensor_list:
num_elements = num_elements + tens.numel()
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
return _flatten_dense_tensors(padded_tensor_list)
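# Editor's worked example (illustration only): with alignment = 4 and tensors of
# sizes [3, 4], num_elements = 7, so a 1-element zero pad tensor is appended and
# the flattened buffer holds 8 elements; with sizes [3, 5] the total is already a
# multiple of 4 and no padding is added.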
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
def get_all_parameters(sub_module):
return itertools.chain(sub_module.named_parameters(recurse=False),
sub_module.ds_external_parameters())
# apply a torch.autograd.Function that calls backward_function to every tensor in outputs
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module,
functional,
backward_function,
output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
else:
return outputs
# for each tensor in outputs, run forward_function and register backward_function as a hook
def _apply_forward_and_backward_to_tensors_only(module,
forward_function,
backward_function,
outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_forward_and_backward_to_tensors_only(
module,
forward_function,
backward_function,
output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
forward_function(outputs)
if outputs.requires_grad:
outputs.register_hook(backward_function)
return outputs
else:
return outputs
# TODO Needs to be implemented
class PrefetchCoordinator(object):
def __init__(self):
# step_id keeps track of the number of sub-modules invoked so far
# the step_id is tracking forward and backward sequence of sub-modules
self.step_id = 0
# stores the sequence of sub modules in forward+backward pass
self.sub_module_trace = []
# maps sub_module id to submodule objects
self.id_to_sub_module_map = {}
        # stores the total number of parameters in each sub_module
self.id_to_sub_module_size_map = {}
self.trace_completed = False
self.most_recent_sub_module_step = {}
# reuse distances
self.reuse_numel_for_step_id = {}
def record_trace(self, sub_module):
if not self.trace_completed:
self.sub_module_trace.append(sub_module.id)
self.id_to_sub_module_map[sub_module.id] = sub_module
def print_trace(self):
print_rank_0(
f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}"
)
def increment_step(self, sub_module):
self.most_recent_sub_module_step[sub_module.id] = self.step_id
self.step_id += 1
def reset_step(self):
self.step_id = 0
    # returns the parameters (up to roughly `numel` elements) that will be used
    # next but are not yet available or in flight
def get_params_to_prefetch(self, sub_module, numel=2000000):
# numel_in_sub_module = 0
# for name, param in sub_module.named_parameters(recurse=False):
# numel_in_sub_module += param.ds_numel
# #if numel_in_sub_module < (numel // 2):
# return []
# tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing
if sub_module.id != self.sub_module_trace[self.step_id]:
print_rank_0(
f"Tracing failed. Prefetching is disabled at sub-module: {sub_module.id}"
)
return []
params_to_prefetch = []
total_numel_to_prefetch = 0
for i in range(self.step_id, len(self.sub_module_trace)):
module_id = self.sub_module_trace[i]
for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]):
if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and (
param.ds_id not in [p.ds_id for p in params_to_prefetch]):
params_to_prefetch.append(param)
total_numel_to_prefetch += param.ds_numel
#print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}")
if total_numel_to_prefetch >= numel: # and total_numel_to_prefetch > (numel_in_sub_module // 2):
return params_to_prefetch
return params_to_prefetch
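    # Editor's note (illustration only): once the trace is complete, a call such
    # as get_params_to_prefetch(sub_module, numel=2000000) walks the recorded
    # trace forward from the current step and collects NOT_AVAILABLE parameters
    # until roughly 2M elements are queued; prefetch_next_sub_modules (further
    # below) then all-gathers them asynchronously on the communication stream.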
# checks if this sub_module will be used again and if so then returns the number of elements
# in the parameters used between this sub_module and the reuse of this sub_module
def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None):
#assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation"
is_there_reuse = False
reuse_distance_in_numel = 1000000000000
# set the appropriate trace
trace = self.sub_module_trace
total_steps = len(trace)
if sub_module_step_id is None:
sub_module_step_id = self.most_recent_sub_module_step[sub_module.id]
# tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing
if sub_module.id != trace[sub_module_step_id]:
print_rank_0(
f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused"
)
return reuse_distance_in_numel
# return cached value
if sub_module_step_id in self.reuse_numel_for_step_id:
return self.reuse_numel_for_step_id[sub_module_step_id]
start_step = self.step_id
print_rank_0(f"Step id is {self.step_id} ")
for step_id in range(start_step, total_steps):
print_rank_0(f"Trace id {trace[step_id]} and sub_module id {sub_module.id}")
if sub_module.id == trace[step_id]:
end_step = step_id
is_there_reuse = True
reuse_distance_in_numel = self._distance_in_numel(
start_step,
end_step,
trace)
break
self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel
return reuse_distance_in_numel
def _distance_in_numel(self, start_step, end_step, trace):
distance_in_numel = 0
for step_id in range(start_step, end_step):
module_id = trace[step_id]
for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False):
distance_in_numel += param.ds_numel
for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters():
distance_in_numel += param.ds_numel
return distance_in_numel
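    # Editor's worked example (illustration only): suppose the recorded trace is
    # [A, B, C, A] and the coordinator is at step_id = 1 (module B). For module A,
    # the next reuse is at step 3, so the reuse distance is the total parameter
    # numel of the modules at steps 1 and 2 (B and C). If that distance is below
    # max_reuse_distance_in_numel, A's parameters are kept resident rather than
    # being re-partitioned on release.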
class PartitionedParameterCoordinator(object):
def __init__(self,
comm_stream=None,
max_reuse_distance_in_numel=500000000,
max_available_parameters_in_numel=700000000):
self.in_flight_handles = []
self.params_in_flight = []
self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream(
)
self.prefetch_coordinator = PrefetchCoordinator()
self.hierarchy = 0
self.total_available_parameter_numel = 0
self.max_available_parameters_in_numel = max_available_parameters_in_numel
# max distance between two use of the module beyond which module is released
self.max_reuse_distance_in_numel = max_reuse_distance_in_numel
def _increment_available_parameter_numel(self, increment):
self.total_available_parameter_numel += increment
def _decrement_available_parameter_numel(self, decrement):
self.total_available_parameter_numel -= decrement
'''-----------------------Tracing and Prefetching ---------------'''
def record_trace(self, sub_module):
self.prefetch_coordinator.record_trace(sub_module)
def finish_tracing(self, print_trace=False):
self.prefetch_coordinator.trace_completed = True
if print_trace:
self.prefetch_coordinator.print_trace()
    # Prefetches the parameters for sub_modules that come after
    # the current sub_module. This call is asynchronous
def prefetch_next_sub_modules(self, sub_module, numel=5000000):
params_to_prefetch = []
if not self.prefetch_coordinator.trace_completed:
return params_to_prefetch
# prefetch if there is no current prefetching in flight
if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel:
params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch(
sub_module,
numel=numel)
self._all_gather(params_to_prefetch, async_op=True)
for param in params_to_prefetch:
param.ds_status = ZeroParamStatus.INFLIGHT
                # keeping track of the number of elements consumed by available parameters
self._increment_available_parameter_numel(param.ds_numel)
self._print_prefetch_elements_info(sub_module, params_to_prefetch)
print_rank_0(
f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}",
force=False)
def _print_prefetch_elements_info(self, sub_module, params_to_prefetch):
sub_module_numel = 0.0
for name, param in sub_module.named_parameters(recurse=False):
sub_module_numel += param.ds_numel
numel_being_prefetched = 0
for param in params_to_prefetch:
            numel_being_prefetched += param.ds_numel
print_rank_0(
f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}",
force=False)
def increment_step(self, sub_module):
self.prefetch_coordinator.increment_step(sub_module)
def reset_step(self):
self.prefetch_coordinator.reset_step()
'''----------------------------------------------------------------------'''
# Fetches the parameters in the sub_module
# This call is blocking
def fetch_sub_module(self, sub_module):
partitioned_params = []
params_in_flight = False
#print_rank_0(f"{'--' * self.hierarchy}Fetching params in module {sub_module.__class__.__name__}")
params_to_fetch = [
param for _,
param in sub_module.named_parameters(recurse=False)
]
if hasattr(sub_module, 'ds_external_parameters'):
print_rank_0(
f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}"
)
params_to_fetch += [
param for _,
param in sub_module.ds_external_parameters()
]
# for _, param in sub_module.named_parameters(recurse=False):
for param in params_to_fetch:
param.ds_active_sub_modules += 1
print_rank_0(
f"{'--' * self.hierarchy}--Fetching parameters {param.ds_id} with active sub modules {param.ds_active_sub_modules}"
)
if param.ds_status == ZeroParamStatus.AVAILABLE:
print_rank_0(
f"{'--' * self.hierarchy}--Parameter {param.ds_id} is already available"
)
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
print_rank_0(
f"{'--' * self.hierarchy}--Parameter {param.ds_id} is being fetched")
partitioned_params.append(param)
                # keeping track of the number of elements consumed by available parameters
self._increment_available_parameter_numel(param.ds_numel)
print_rank_0(f"Incrementing with parameter id {param.ds_id}")
if param.ds_status == ZeroParamStatus.INFLIGHT:
params_in_flight = True
print_rank_0(
f"{'--' * self.hierarchy}--Parameters {param.ds_id} is already in flight (prefetched)"
)
self.hierarchy += 1
# parameters are partitioned and need to be allgathered
self._all_gather(partitioned_params, async_op=True)
# parameters are inflight and communication needs to be completed
if partitioned_params or params_in_flight:
self._synchronize_communication()
for _, param in sub_module.named_parameters(recurse=False):
param.ds_status = ZeroParamStatus.AVAILABLE
#print(f"Param id {param.ds_id}, Shape {param.shape}, device {param.device} ")
#print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}")
def release_sub_module(self, sub_module):
self.hierarchy -= 1
print_rank_0(
f"{'--' * self.hierarchy}Releasing params in module {sub_module.__class__.__name__}"
)
params_to_release = [
param for _,
param in sub_module.named_parameters(recurse=False)
]
if hasattr(sub_module, 'ds_external_parameters'):
#print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}")
params_to_release += [
param for _,
param in sub_module.ds_external_parameters()
]
# for _, param in sub_module.named_parameters(recurse=False):
for param in params_to_release:
param.ds_active_sub_modules -= 1
if not param.ds_active_sub_modules and not self._keep_for_later(
sub_module) and not param.ds_persist:
print_rank_0(
f"{'--' * self.hierarchy}--Releasing parameters {param.ds_id} with numel {param.numel()} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}"
)
# Keeping track of number of elements that are consumed by available parameters
self._decrement_available_parameter_numel(param.ds_numel)
see_memory_usage(
f"Before releasing param {param.ds_id} with numel{param.numel()}",
force=False)
param.partition(hierarchy=self.hierarchy)
see_memory_usage(
f"After releasing param {param.ds_id} has numel{param.numel()} ",
force=False)
param.ds_status = ZeroParamStatus.NOT_AVAILABLE
else:
print_rank_0(
f"{'--' * self.hierarchy}--Did not release parameters {param.ds_id} with numel {param.numel()} with active sub modules {param.ds_active_sub_modules}, keep for later {self._keep_for_later(sub_module)} and persistence {param.ds_persist}"
)
def release_and_reset_parameter(self, param):
param.ds_active_sub_modules = 0
if param.ds_status == ZeroParamStatus.AVAILABLE:
print_rank_0(
f"Releasing unpartitioned {param.ds_id} active sub-modules {param.ds_active_sub_modules} size {param.ds_numel} and persisitence {param.ds_persist}"
)
self._decrement_available_parameter_numel(param.ds_numel)
param.partition()
def _keep_for_later(self, sub_module):
if not self.prefetch_coordinator.trace_completed:
return False
reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel(
sub_module)
#print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}")
return reuse_distance_in_numel < self.max_reuse_distance_in_numel
def _all_gather(self, partitioned_params, async_op=False):
with torch.cuda.stream(self.comm_stream):
handles = partitioned_params[0].all_gather(
param_list=partitioned_params,
async_op=async_op,
hierarchy=self.hierarchy) if partitioned_params else None
if handles is not None:
self.in_flight_handles.extend(handles)
self.params_in_flight.extend(partitioned_params)
def _synchronize_communication(self, synchronize_streams=True):
assert len(self.params_in_flight) == len(self.in_flight_handles)
for handle, param in zip(self.in_flight_handles, self.params_in_flight):
if handle is not None:
with torch.cuda.stream(self.comm_stream):
handle.wait()
param.ds_status = ZeroParamStatus.AVAILABLE
self.comm_stream.synchronize()
        if synchronize_streams:
            torch.cuda.synchronize()
self.in_flight_handles = []
self.params_in_flight = []
class PreBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, outputs):
ctx.module = module
ctx.pre_backward_function = pre_backward_function
module.applied_pre_backward = False
#print(f"After Forward: {ctx.module.__class__.__name__}")
outputs = outputs.detach()
return outputs
@staticmethod
def backward(ctx, *args):
#print(f"Before Backward: {ctx.module.__class__.__name__}")
ctx.pre_backward_function(ctx.module)
return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, output):
ctx.module = module
if output.requires_grad:
            #TODO: sometimes post-backward does not seem to be triggered; debug in detail.
            #Should only cause an increase in memory, not a correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
# ctx.view=True
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
#if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.pre_backward_function = pre_backward_function
output = output.detach()
return output
@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.pre_backward_function(ctx.module)
#print(f"After Backward: {ctx.module.__class__.__name__}")
return (None, None) + args
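# Editor's sketch (illustration only; the callback names are placeholders): both
# autograd Functions above are identity ops in the forward pass and exist only to
# run a callback at the right moment of the backward pass. Roughly:
#
#     out = PreBackwardFunction.apply(module, fetch_callback, out)
#     # backward: fetch_callback(module) runs *before* module's own backward,
#     # re-gathering any partitioned parameters it needs
#     out = PostBackwardFunction.apply(module, release_callback, out)
#     # backward: release_callback(module) runs once ds_grads_remaining drops to
#     # zero, i.e. after the module's backward has finished for every wrapped tensor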
INITIAL_MICRO_STEP_ID = -1
class FP16_DeepSpeedZeroOptimizer_Stage3(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
    For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
https://arxiv.org/abs/1910.02054
For usage examples, refer to TODO: DeepSpeed Tutorial
"""
def __init__(self,
module,
init_optimizer,
timers,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
contiguous_gradients=True,
reduce_bucket_size=500000000,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
dp_process_group=None,
reduce_scatter=True,
overlap_comm=False,
cpu_offload_optimizer_state=False,
cpu_offload_params=False,
cpu_offload_use_pin_memory=False,
sub_group_size=1000000000000,
mpu=None,
clip_grad=0.0,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1,
elastic_checkpoint=False):
see_memory_usage("Stage 3 intialize begining", force=True)
if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {prefetch_bucket_size}")
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain same user API from apex.fp16_utils
        # 2. keep common stuff here in case we need to add a new fused optimizer later
        # differences from apex.fp16_utils:
        # - assume all model params in fp16
        # - assume all params requires grad
        # - flat by groups, not keeping state. TODO: remove state explicitly?
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
        if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
if not all(is_zero_param(p) for p in module.parameters()):
group = None
if mpu:
group = mpu.get_data_parallel_group()
Init(module=module, data_parallel_group=group)
for m in module.modules():
_init_external_params(m)
self.module = module
self.elastic_checkpoint = elastic_checkpoint
self.overlap_comm = overlap_comm
if self.overlap_comm:
self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()
######################cpu offload setup##################################
self.cpu_offload = cpu_offload_optimizer_state
self.cpu_offload_use_pin_memory = cpu_offload_use_pin_memory
if cpu_offload_params:
assert cpu_offload_optimizer_state, "parameter offload is only available with optimizer state offload"
self.cpu_offload_params = cpu_offload_optimizer_state and cpu_offload_params
self.deepspeed_adam_offload = (self.cpu_offload
and type(init_optimizer) == DeepSpeedCPUAdam)
self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
############################################################################
see_memory_usage("Before Partitioned Parameter Coordinator", force=True)
fetch_stream = torch.cuda.Stream() if self.overlap_comm else None
self.param_coordinator = PartitionedParameterCoordinator(
comm_stream=fetch_stream,
max_reuse_distance_in_numel=int(max_reuse_distance),
max_available_parameters_in_numel=int(max_live_parameters))
see_memory_usage("After Partitioned Parameter Coordinator", force=True)
#self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream())
#-------------Stage 3 Setup-------------------#
# parameters smaller than the threshold will be collectively gathered at the
# end of the optimizer step and will be kept till the end of the backward pass
# TODO maybe worth just replicating these parameters and doing all reduce for them
self.persistence_threshold = int(param_persistence_threshold)
self.persistent_parameters = self.persistent_parameters()
self.setup_zero_stage3_hooks()
#resetting ds_tensor just in case parameters have been changed after initialization
#example .half() or .to()
#self.reset_ds_tensor()
#---------------------------------------------#
self.timers = timers
self.reduce_scatter = reduce_scatter
self.dp_process_group = dp_process_group
self.partition_count = dist.get_world_size(group=self.dp_process_group)
if mpu is None:
self.model_parallel_group = None
self.model_parallel_rank = 0
else:
self.model_parallel_group = mpu.get_model_parallel_group()
self.model_parallel_rank = mpu.get_model_parallel_rank()
self.overflow = False
self.clip_grad = clip_grad
self.allreduce_always_fp32 = allreduce_always_fp32
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = INITIAL_MICRO_STEP_ID
if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
        # Holds the model parameters
        # The param.data may not hold any meaningful data
        # when the param's status is NOT_AVAILABLE or INFLIGHT
self.fp16_groups = []
# Hold partitioned parameters
self.fp16_partitioned_groups = []
# Holds a fused and flattened copy of the parameters
self.fp16_partitioned_groups_flat = []
#a single 32-bit partition of the parallel partitioned parameters
#that this process will update
self.fp32_partitioned_groups_flat = []
# number of elements per partition in each group
self.partition_size = []
self.all_reduce_print = False
self.prefetch_elements = int(prefetch_bucket_size)
# padding on each partition for alignment purposes
self.groups_padding = []
self.sub_group_size = sub_group_size
self.sub_group_to_group_id = {}
see_memory_usage("Before creating fp16 partitions", force=True)
#self._create_fp16_partitions()
self._create_fp16_partitions_with_defragmentation()
num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
force=True)
see_memory_usage("Before creating fp32 partitions", force=True)
self._create_fp32_partitions()
see_memory_usage("After creating fp32 partitions", force=True)
see_memory_usage("Before initializing optimizer states", force=True)
self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states", force=True)
if dist.get_rank() == 0:
logger.info(f"optimizer state initialized")
self.reduce_bucket_size = int(reduce_bucket_size)
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream(
) if self.overlap_comm else torch.cuda.current_stream()
self.callback_queued = False
self.copy_grad_stream = torch.cuda.Stream()
self.param_dict = {}
# map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
self.contiguous_gradients = contiguous_gradients
self.extra_large_param_to_reduce = None
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
self.params_already_reduced = []
self._release_ipg_buffers()
self.previous_reduced_grads = None
# simplified param id
self.param_id = {}
count = 0
for i, params_group in enumerate(self.fp16_groups):
for param in params_group:
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
count = count + 1
#Largest partitioned param
largest_partitioned_param_numel = self._get_largest_partitioned_numel()
see_memory_usage(f"Before Set Grad positions", force=True)
self.grad_position = {}
self.set_grad_positions()
see_memory_usage(f"Before CPU Offload initialization", force=True)
self.grads_in_partition = None
if self.cpu_offload:
self.accumulated_grads_in_cpu = {}
self.norm_for_param_grads = {}
self.local_overflow = False
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
largest_partitioned_param_numel,
device=torch.cuda.current_device()).half()
self.temp_grad_gpu_buffer = torch.zeros(
largest_partitioned_param_numel,
device=torch.cuda.current_device()).half()
see_memory_usage(f"After CPU Offload initialization", force=True)
# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
# stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
        # will store the averaged gradients required by this partition
self.averaged_gradients = {}
#creates backward hooks for gradient partitioning
self.create_reduce_and_remove_grad_hooks()
#exit(0)
        # we may have a way of fusing the dynamic loss scale; not supported for now
if dynamic_loss_scale:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
self.debug_fp16_grads = [{} for _ in self.fp16_groups]
if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(f"After initializing ZeRO optimizer", force=True)
def _get_largest_partitioned_numel(self):
largest_partitioned_param_numel = 0
for partitioned_params_group in self.fp16_partitioned_groups:
for partitioned_param in partitioned_params_group:
if partitioned_param.numel() > largest_partitioned_param_numel:
largest_partitioned_param_numel = partitioned_param.numel()
return largest_partitioned_param_numel
def _create_fp16_partitions(self):
dist.barrier()
partition_id = dist.get_rank(group=self.dp_process_group)
# loop to deal with groups
for j, param_group in enumerate(self.optimizer.param_groups):
sub_groups = self._create_fp16_sub_groups(param_group['params'])
for sub_group in sub_groups:
i = len(self.fp16_groups)
# push this group to list before modify
self.fp16_groups.append(sub_group)
self.sub_group_to_group_id[i] = j
                #These are the lists of the partitioned parameters
self.fp16_partitioned_groups.append(
[param.ds_tensor for param in self.fp16_groups[i]])
print_rank_0(
f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}"
)
# Record padding required to align group to world size (only applies to last rank)
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
padding = [p.padding_size() for p in self.fp16_groups[i]]
else:
padding = [0] * len(self.fp16_groups[i])
self.groups_padding.append(padding)
#not sure why apex was cloning the weights before flattening
#removing cloning here
see_memory_usage(f"Before Flattening param group {i}", force=False)
if not self.cpu_offload_params:
see_memory_usage(f"Before moving param group {i} to CPU",
force=False)
#move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_partitioned_groups[i])
see_memory_usage(f"After moving param group {i} to CPU", force=False)
#create flat buffer in CPU and move to GPU
self.fp16_partitioned_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
see_memory_usage(
f"After flattening and moving param group {i} to GPU",
force=False)
else:
#Without the detach, seems like the flattening becomes part of the
#model graph causing errors downstream
self.fp16_partitioned_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
dist.get_world_size(
group=self.dp_process_group)).detach().pin_memory())
see_memory_usage(f"After Flattening param group {i}", force=False)
see_memory_usage(f"After Flattening param group {i}", force=False)
#set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(
self.fp16_partitioned_groups_flat[i],
self.fp16_partitioned_groups[i])
for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params):
partitioned_param.data = q.data
def _move_to_flat_buffer(self, src_list, flat_buffer):
start = 0
for src in src_list:
dest = flat_buffer.narrow(0, start, src.numel())
start = start + src.numel()
dest.data.copy_(src.data)
src.data = dest.data
def _create_fp16_partitions_with_defragmentation(self):
dist.barrier()
partition_id = dist.get_rank(group=self.dp_process_group)
if self.cpu_offload_params:
self.param_groups_fp16_flat_cpu_memory = []
for j, param_group in enumerate(self.optimizer.param_groups):
total_params = sum([p.ds_tensor.numel() for p in param_group['params']])
self.param_groups_fp16_flat_cpu_memory.append(
torch.empty(total_params,
dtype=torch.half,
pin_memory=True))
# loop to deal with groups
for j, param_group in enumerate(self.optimizer.param_groups):
sub_groups = self._create_fp16_sub_groups(param_group['params'])
flat_offset = 0
for sub_group in sub_groups:
i = len(self.fp16_groups)
# push this group to list before modify
self.fp16_groups.append(sub_group)
self.sub_group_to_group_id[i] = j
                #These are the lists of the partitioned parameters
self.fp16_partitioned_groups.append(
[param.ds_tensor for param in self.fp16_groups[i]])
print_rank_0(
f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}"
)
# Record padding required to align group to world size (only applies to last rank)
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
padding = [p.padding_size() for p in self.fp16_groups[i]]
else:
padding = [0] * len(self.fp16_groups[i])
self.groups_padding.append(padding)
#not sure why apex was cloning the weights before flattening
#removing cloning here
see_memory_usage(f"Before Flattening param group {i}", force=False)
if not self.cpu_offload_params:
see_memory_usage(f"Before moving param group {i} to CPU",
force=False)
#move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_partitioned_groups[i])
see_memory_usage(f"After moving param group {i} to CPU", force=False)
#create flat buffer in CPU and move to GPU
self.fp16_partitioned_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
see_memory_usage(
f"After flattening and moving param group {i} to GPU",
force=False)
else:
total_elements = sum(
[t.numel() for t in self.fp16_partitioned_groups[i]])
fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[
j].narrow(0,
flat_offset,
total_elements)
self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
self.fp16_partitioned_groups_flat[i])
flat_offset += total_elements
see_memory_usage(f"After Flattening param group {i}", force=False)
def _create_fp32_partitions(self):
for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
# a partition of the fp32 master weights that will be updated by this process
self.fp32_partitioned_groups_flat.append(
self.fp16_partitioned_groups_flat[i].to(
self.device).clone().float().detach())
element_size = self.fp32_partitioned_groups_flat[i].element_size()
num_elements = self.fp32_partitioned_groups_flat[i].numel()
self.fp32_partitioned_groups_flat[
i].requires_grad = True # keep this in case internal optimizer uses it
# Clear for on-the-fly population before the optimizer step
for param_group in self.optimizer.param_groups:
param_group['params'] = []
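# Splits a param group into sub-groups of roughly sub_group_size partitioned
# elements each. Illustrative example (assumed numbers): with sub_group_size=5
# and per-param ds_tensor numels [3, 3, 4], the result is [[p0, p1], [p2]].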
def _create_fp16_sub_groups(self, params_group):
params_group_numel = sum([param.ds_tensor.numel() for param in params_group])
sub_group_size = self.sub_group_size
if sub_group_size is None or sub_group_size >= params_group_numel:
return [params_group]
sub_groups = []
sub_group = []
local_sub_group_size = 0
for param in params_group:
sub_group.append(param)
local_sub_group_size += param.ds_tensor.numel()
if local_sub_group_size >= sub_group_size or id(param) == id(
params_group[-1]):
sub_groups.append(sub_group)
sub_group = []
local_sub_group_size = 0
return sub_groups
# def reset_ds_tensor(self):
# for name, param in self.module.named_parameters(recurse=True):
# assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible"
# assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now"
# param.ds_tensor.data = param.data
def setup_zero_stage3_hooks(self):
self.hierarchy = 0
self._register_hooks_recursively(self.module)
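# Parameters smaller than persistence_threshold are flagged ds_persist; the
# parameter coordinator is expected to keep such params resident instead of
# re-partitioning them after every use, presumably to avoid the overhead of
# many tiny all-gathers.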
def persistent_parameters(self):
persistent_params = []
total_persistent_parameters = 0
for _, param in self.module.named_parameters(recurse=True):
if param.ds_numel < self.persistence_threshold:
param.ds_persist = True
persistent_params.append(param)
total_persistent_parameters += param.ds_numel
print_rank_0(
f'ZeRO 3: Total persistent parameters: {total_persistent_parameters}',
force=False)
return persistent_params
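# Walks the module tree and registers, for every submodule: a forward pre-hook
# (fetch/prefetch its partitioned params), a forward hook (release them), and
# two more hooks that wrap the module's outputs/inputs in autograd Functions so
# parameters are re-fetched before and released after its backward pass.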
def _register_hooks_recursively(self, module, count=[0]):
my_count = count[0]
module.id = my_count
#print(f"{module.__class__} : {module.id}")
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
def _pre_forward_module_hook(module, *args):
self.pre_sub_module_forward_function(module)
def _post_forward_module_hook(module, *args):
self.post_sub_module_forward_function(module)
def _pre_backward_module_hook(module, inputs, output):
def _run_before_backward_function(sub_module):
if sub_module.applied_pre_backward is False:
self.pre_sub_module_backward_function(sub_module)
sub_module.applied_pre_backward = True
return _apply_to_tensors_only(module,
PreBackwardFunction,
_run_before_backward_function,
output)
#This is an alternate to doing _post_backward_module_hook
#it uses tensor.register_hook instead of using torch.autograd.Function
def _alternate_post_backward_module_hook(module, inputs):
module.ds_grads_remaining = 0
#print(f"Before Forward {module.__class__.__name__}")
def _run_after_backward_hook(*unused):
module.ds_grads_remaining = module.ds_grads_remaining - 1
if module.ds_grads_remaining == 0:
#print(f"After backward {module.__class__.__name__}")
self.post_sub_module_backward_function(module)
def _run_before_forward_function(input):
if input.requires_grad:
module.ds_grads_remaining += 1
return _apply_forward_and_backward_to_tensors_only(
module,
_run_before_forward_function,
_run_after_backward_hook,
inputs)
def _post_backward_module_hook(module, inputs):
module.ds_grads_remaining = 0
def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)
return _apply_to_tensors_only(module,
PostBackwardFunction,
_run_after_backward_function,
inputs)
# Pre forward hook
module.register_forward_pre_hook(_pre_forward_module_hook)
# Post forward hook
module.register_forward_hook(_post_forward_module_hook)
# Pre backward hook
module.register_forward_hook(_pre_backward_module_hook)
# post backward hook
module.register_forward_pre_hook(_post_backward_module_hook)
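# Records this submodule in the trace, fetches its partitioned parameters,
# prefetches parameters for upcoming submodules up to prefetch_elements, and
# advances the coordinator's step counter.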
def pre_sub_module_forward_function(self, sub_module):
see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
force=False)
self.param_coordinator.record_trace(sub_module)
self.param_coordinator.fetch_sub_module(sub_module)
see_memory_usage(
f"Before sub module function {sub_module.__class__.__name__} after fetch",
force=False)
self.param_coordinator.prefetch_next_sub_modules(sub_module,
numel=self.prefetch_elements)
see_memory_usage(
f"Before sub module function {sub_module.__class__.__name__} after prefetch",
force=False)
self.param_coordinator.increment_step(sub_module)
def post_sub_module_forward_function(self, sub_module):
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} before release",
force=False)
self.param_coordinator.release_sub_module(sub_module)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} after release",
force=False)
def pre_sub_module_backward_function(self, sub_module):
self.param_coordinator.record_trace(sub_module)
self.param_coordinator.fetch_sub_module(sub_module)
self.param_coordinator.prefetch_next_sub_modules(sub_module,
numel=self.prefetch_elements)
self.param_coordinator.increment_step(sub_module)
def post_sub_module_backward_function(self, sub_module):
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} before release",
force=False)
self.param_coordinator.release_sub_module(sub_module)
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} after release",
force=False)
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
if not self.cpu_offload:
self.grads_in_partition = None
self.grads_in_partition_offset = 0
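# Temporarily installs the fp32 flat partition of this sub-group as the only
# parameter of its optimizer param group, runs the wrapped optimizer's step,
# clears the group again, and copies the updated fp32 values back into the
# fp16 flat partition.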
def _optimizer_step(self, sub_group_id):
param_group_id = self.sub_group_to_group_id[sub_group_id]
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
self.optimizer.param_groups[param_group_id]['params'] = [fp32_param]
self.optimizer.step()
self.optimizer.param_groups[param_group_id]['params'] = []
fp16_param.data.copy_(fp32_param.data)
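# Runs one optimizer step per sub-group against zero gradients so the wrapped
# optimizer allocates its state tensors up front. A shared zero buffer sized to
# the largest sub-group is reused when neither CPU offload nor pinned memory is
# in use.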
def initialize_optimizer_states(self):
num_subgroups = len(self.fp16_groups)
largest_numel = max([t.numel() for t in self.fp16_partitioned_groups_flat])
gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype
gradient_buffer = torch.zeros(int(largest_numel),
dtype=gradient_dtype,
device=self.device)
for i, group in enumerate(self.fp16_groups):
see_memory_usage(
f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups',
force=False)
num_elements = int(self.fp16_partitioned_groups_flat[i].numel())
if self.cpu_offload and not self.cpu_offload_use_pin_memory:
self.fp32_partitioned_groups_flat[i].grad = torch.zeros(
num_elements,
dtype=gradient_dtype,
device=self.device)
elif self.cpu_offload_use_pin_memory:
self.fp32_partitioned_groups_flat[i].grad = torch.zeros(
num_elements,
dtype=gradient_dtype,
device=self.device).pin_memory()
else:
self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(
0,
0,
num_elements)
self._optimizer_step(i)
see_memory_usage(
f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups',
force=False)
if not self.cpu_offload:
for group in self.fp32_partitioned_groups_flat:
group.grad = None
return
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, param_group in enumerate(self.fp16_groups):
self.param_to_partition_ids[i] = {}
self.is_partition_reduced[i] = {}
self.total_grads_in_partition[i] = {}
self.remaining_grads_in_partition[i] = {}
self.is_grad_computed[i] = {}
self.grad_partition_insertion_offset[i] = {}
self.grad_start_offset[i] = {}
self.first_param_index_in_partition[i] = {}
for partition_id in range(total_partitions):
self.is_grad_computed[i][partition_id] = {}
self.grad_partition_insertion_offset[i][partition_id] = {}
self.grad_start_offset[i][partition_id] = {}
self.initialize_gradient_partition(i, param_group, partition_id)
self.is_partition_reduced[i][partition_id] = False
self.first_param_index_in_partition[i][
partition_id] = self.get_first_param_index(
i,
param_group,
partition_id)
def independent_gradient_partition_epilogue(self):
self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
self.reduce_ipg_grads()
self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
if self.overlap_comm:
self.reduction_stream.synchronize()
with torch.cuda.stream(self.reduction_stream):
self.partition_previous_reduced_grads()
# if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
#in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad
#TODO: use a similar code path for both cpu_offload and non-cpu offload
if not self.cpu_offload:
for i, sub_group in enumerate(self.fp16_groups):
self.averaged_gradients[i] = [
torch.zeros_like(param.ds_tensor) if param.grad is None else
param.grad.data.narrow(0,
0,
param.ds_tensor.numel())
for param in sub_group
]
# self.averaged_gradients[i] = self.get_flat_partition(
# self.fp16_groups[i],
# 0,
# self.fp32_partitioned_groups_flat[i].numel(),
# return_tensor_list=True)
self._release_ipg_buffers()
see_memory_usage(f"End ipg_epilogue", force=False)
# resets all partitions to not reduced
# sets remaining grads to the total number of grads in each partition
# sets is_grad_computed to False for all grads in each partition
def reset_partition_gradient_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.fp16_groups):
for partition_id in range(total_partitions):
self.is_partition_reduced[i][partition_id] = False
self.remaining_grads_in_partition[i][
partition_id] = self.total_grads_in_partition[i][partition_id]
for param_id in self.is_grad_computed[i][partition_id]:
self.is_grad_computed[i][partition_id][param_id] = False
def initialize_gradient_partition(self, i, param_group, partition_id):
def set_key_value_list(dictionary, key, value):
if key in dictionary:
dictionary[key].append(value)
else:
dictionary[key] = [value]
def increment_value(dictionary, key):
if key in dictionary:
dictionary[key] += 1
else:
dictionary[key] = 1
partition_size = self.partition_size[i]
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for param in param_group:
param_size = param.numel()
param_id = self.get_param_id(param)
if (current_index >= start_index and current_index < end_index):
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][
param_id] = current_index - start_index
self.grad_start_offset[i][partition_id][param_id] = 0
elif start_index > current_index and start_index < (current_index +
param_size):
assert (first_offset == 0), "This can happen at most once, as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
self.grad_start_offset[i][partition_id][param_id] = first_offset
current_index = current_index + param_size
def overlapping_partition_gradients_reduce_epilogue(self):
self.independent_gradient_partition_epilogue()
self.zero_grad()
def create_reduce_and_remove_grad_hooks(self):
print_rank_0(f'[Begin] Create gradient reduction hooks')
self.grad_accs = []
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
if param.requires_grad:
#print_rank_0(f" Before all gather {param.device}, {param.shape}")
# The hook must be created on the un-partitioned parameter
param.all_gather()
#print(f"After all gather {param.device}, {param.shape}")
def wrapper(param, i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param, i)
grad_acc.register_hook(reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)
#print(f"param grad fn {param.expand_as(param).grad_fn}")
wrapper(param, i)
# Partition the parameter after creating the hook
param.partition()
print_rank_0(f'[End] Create gradient reduction hooks')
def get_param_id(self, param):
unique_id = id(param)
return self.param_id[unique_id]
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
see_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}",
force=False)
############### Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
#print_rank_0(f"Inside reduce ipg buckets. Param ID {param.ds_id}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True)
if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
param.ds_numel)
self.reduce_ipg_grads()
if self.contiguous_gradients and self.overlap_comm:
# Swap ipg_index between 0 and 1
self.ipg_index = 1 - self.ipg_index
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
param.ds_numel)
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
# keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
if param.ds_numel > self.reduce_bucket_size:
self.extra_large_param_to_reduce = param
elif self.contiguous_gradients:
#print_rank_0("before new grad tensor move")
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
0,
self.elements_in_ipg_bucket,
param.ds_numel)
#print_rank_0("after new grad tensor move")
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.elements_in_ipg_bucket += param.ds_numel
self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
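# All-reduces a gradient tensor, optionally upcast to fp32. With
# postscale_gradients, the tensor is pre-divided by gradient_predivide_factor
# and rescaled after the all-reduce; otherwise it is divided by the data
# parallel world size before the all-reduce.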
def gradient_reduction_w_predivide(self, tensor):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.postscale_gradients:
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.gradient_predivide_factor != dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor /
dp_world_size)
else:
tensor_to_allreduce.div_(dp_world_size)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
def average_tensor(self, tensors, params_to_reduce):
with torch.cuda.stream(self.reduction_stream):
if not self.reduce_scatter:
for tensor in tensors:
self.gradient_reduction_w_predivide(tensor)
return
for tensor in tensors:
tensor.div_(dist.get_world_size(group=self.dp_process_group))
# reduction resulting in each rank only holding the gradient partition it owns
# This could either be a reduce scatter or a reduce op depending on how
# parameters are partitioned. The method is implemented by the
# DeepSpeed param extensions to the PyTorch parameter, so it's up to
# the extension to define what happens here
params_to_reduce[0].reduce_gradients_at_owner(
param_list=params_to_reduce,
hierarchy=self.param_coordinator.hierarchy)
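# Records, for every parameter, the (sub-group index, offset, numel) of its
# gradient slice within the flat partition so reduced gradients can later be
# copied or accumulated into the right location.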
def set_grad_positions(self):
for i, group in enumerate(self.fp16_groups):
current_offset = 0
for param in group:
param_id = self.get_param_id(param)
num_elements = param.ds_tensor.numel()
self.grad_position[param_id] = [
int(i),
int(current_offset),
int(num_elements)
]
#print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}")
current_offset += num_elements
def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition):
# copy to a pre-existing buffer to avoid the memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
0,
0,
param.ds_tensor.numel())
if self.micro_step_id > 0:
dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True)
param.grad.data.view(-1).add_(dest_buffer)
# at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1),
non_blocking=True)
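# Computes the 2-norm of a tensor in fixed-size chunks, accumulating squared
# chunk norms in double precision, to bound temporary memory instead of
# materializing a full double-precision copy at once.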
def _constant_buffered_norm2(self, input, buffer_size=250000000):
norm = None
for part in input.view(-1).split(buffer_size):
if norm is None:
norm = part.data.double().norm(2)**2.0
else:
norm += part.data.double().norm(2)**2.0
return norm**0.5
def set_norm_for_param_grad_in_gpu(self, param):
param_id = self.get_param_id(param)
#self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2)
#Using a more memory efficient version
self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad)
def update_overflow_tracker_for_param_grad(self, param):
#Credit to our user David Minn
if param.grad is not None:
if self.overlap_comm:
self.gpu_sum = self.gpu_sum + param.grad.data.float().sum()
elif self._has_inf_or_nan(param.grad.data):
self.local_overflow = True
def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor):
with torch.cuda.stream(self.copy_grad_stream):
param_id = self.get_param_id(param)
src_tensor = param.grad.view(-1).float()
#print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}")
fp32_grad_tensor.copy_(src_tensor, non_blocking=True)
param.grad = None
def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
norm_type = 2.0
for p in params:
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_id = self.get_param_id(p)
if param_id in self.norm_for_param_grads.keys():
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
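# Partitions the gradients of the params reduced in the previous bucket flush.
# It allocates per-group fp16 partition buffers when needed (pinned when
# configured). With cpu_offload the reduced grads are accumulated/copied into
# the CPU fp16 buffer and the fp32 grad buffer on side streams; otherwise they
# are partitioned directly into the fp16 partition buffer.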
def partition_previous_reduced_grads(self):
if not self.previous_reduced_grads:
return
if self.cpu_offload:
allocate_grads_in_partition = self.grads_in_partition is None\
and self.gradient_accumulation_steps > 1
else:
allocate_grads_in_partition = self.grads_in_partition is None
if allocate_grads_in_partition:
self.grads_in_partition = []
for i, group in enumerate(self.fp16_groups):
total_size = 0
for param_in_partition in group:
total_size += param_in_partition.ds_tensor.numel()
see_memory_usage(
f"group {i} before creating {total_size} reduced gradients into partition",
force=True)
if self.cpu_offload_use_pin_memory:
self.grads_in_partition.append(
torch.zeros(int(total_size),
dtype=torch.half,
device=self.device).pin_memory())
else:
self.grads_in_partition.append(
torch.zeros(int(total_size),
dtype=torch.half,
device=self.device))
see_memory_usage(
f"group {i} after creating {total_size} reduced gradients into partition",
force=True)
for param in self.previous_reduced_grads:
[i, dest_offset, num_elements] = self.grad_position[self.get_param_id(param)]
# self.debug_fp16_grads[i][self.get_param_id(param)] = (
# float(param.data.float().norm(2)),
# float(param.grad.data.float().norm(2)))
if self.cpu_offload:
param.partition_gradients(partition_buffers=self.temp_grad_gpu_buffer)
with torch.cuda.stream(self.copy_grad_stream):
self.reduction_stream.synchronize()
if self.gradient_accumulation_steps > 1:
# The allreduce buffer will be rewritten. Copy the gradients in the partition to a new buffer
fp16_grad_tensor = self.grads_in_partition[i].narrow(
0,
dest_offset,
num_elements)
self.async_accumulate_grad_in_cpu_via_gpu(param, fp16_grad_tensor)
if self.is_gradient_accumulation_boundary:
self.set_norm_for_param_grad_in_gpu(param)
self.update_overflow_tracker_for_param_grad(param)
fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
0,
dest_offset,
num_elements)
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(
param,
fp32_grad_tensor)
else:
# The allreduce buffer will be rewritten. Copy the gradients in the partition to a new buffer
fp16_grad_tensor = self.grads_in_partition[i].narrow(
0,
dest_offset,
num_elements)
param.partition_gradients(
partition_buffers=fp16_grad_tensor,
accumulate=True if self.micro_step_id > 0 else False)
self.previous_reduced_grads = []
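# Flushes the current IPG bucket: averages/reduce-scatters the bucketed
# gradients (plus any extra-large param grad held separately), marks the
# bucketed params as reduced, and remembers them so their grads are partitioned
# on the next flush.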
def reduce_ipg_grads(self, extra_param=None):
if self.overlap_comm:
self.reduction_stream.synchronize()
with torch.cuda.stream(self.reduction_stream):
self.partition_previous_reduced_grads()
params_to_reduce = [param for i, param, param_id in self.params_in_ipg_bucket]
#print(f"Params in ipg bucket {self.params_in_ipg_bucket}")
#print(f"Reducing {[(param.ds_id, param.grad) for param in params_to_reduce]}")
#exit(0)
if self.contiguous_gradients:
reduction_list = [self.ipg_buffer[self.ipg_index]]
if self.extra_large_param_to_reduce is not None:
reduction_list.append(self.extra_large_param_to_reduce.grad)
self.extra_large_param_to_reduce = None
self.average_tensor(reduction_list, params_to_reduce)
else:
self.buffered_reduce_fallback(
None,
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
for _, param, param_id in self.params_in_ipg_bucket:
self.params_already_reduced[param_id] = True
self.previous_reduced_grads = params_to_reduce
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
#####################################################################
def reduce_ready_partitions_and_remove_grads(self, param, i):
#print(f"Backward {param.ds_id}")
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
def zero_reduced_gradients(self, partition_id, i):
def are_all_related_partitions_reduced(params_id):
for partition_id in self.param_to_partition_ids[i][params_id]:
if not self.is_partition_reduced[i][partition_id]:
return False
return True
for params_id in self.is_grad_computed[i][partition_id]:
if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None
def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = _flatten_dense_tensors(tensors)
def print_func():
logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
self.sequential_execution(print_func, message)
def get_grads_to_reduce(self, i, partition_id):
def get_reducable_portion(key):
grad = self.param_dict[key].grad
total_elements = grad.numel()
start = self.grad_start_offset[i][partition_id][key]
num_elements = min(
total_elements - start,
self.partition_size[i] -
self.grad_partition_insertion_offset[i][partition_id][key])
if not pg_correctness_test:
if num_elements == total_elements:
return grad
else:
return grad.contiguous().view(-1).narrow(0,
int(start),
int(num_elements))
else:
if num_elements == total_elements:
return grad.clone()
else:
return grad.clone().contiguous().view(-1).narrow(
0,
int(start),
int(num_elements))
grads_to_reduce = []
for key in self.is_grad_computed[i][partition_id]:
grad = get_reducable_portion(key)
grads_to_reduce.append(grad)
return grads_to_reduce
def sequential_execution(self, function, message, group=None):
if group is None:
group = self.dp_process_group
if dist.get_rank(group=group) == 0:
logger.info(message)
for id in range(dist.get_world_size(group=group)):
if id == dist.get_rank(group=group):
function()
dist.barrier(group=group)
def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
param.grad = torch.zeros_like(param)
######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
rank = None
tensor = flatten(bucket)
tensor_to_allreduce = tensor
if pg_correctness_test:
allreduce_always_fp32 = True
if allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = _get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
tensor.copy_(tensor_to_allreduce)
return tensor
# if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
with torch.cuda.stream(self.reduction_stream):
allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self,
bucket,
numel_per_bucket=500000000,
rank=None,
log=None):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
# allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self,
rank,
grads,
elements_per_buffer=500000000,
log=None):
split_buckets = split_half_float_double(grads)
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket,
numel_per_bucket=elements_per_buffer,
rank=rank,
log=log)
#############################################################################
#############################################################################
#############################################################################
# views the tensor as multiple partitions and returns
# those partitions
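# Example (illustrative): a 10-element tensor with dp world size 4 is split
# into views of sizes [3, 3, 2, 2]; the remainder is spread over the first ranks.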
def get_data_parallel_partitions(self, tensor):
partitions = []
dp = dist.get_world_size(group=self.dp_process_group)
dp_id = dist.get_rank(group=self.dp_process_group)
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if (current_index >= start_index and current_index < end_index):
params_in_partition.append(tensor)
elif start_index > current_index and start_index < (current_index +
tensor_size):
params_in_partition.append(tensor)
assert (first_offset == 0), "This can happen at most once, as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _model_parallel_all_reduce(self, tensor, op):
""" Perform all reduce within model parallel group, if any.
"""
if self.model_parallel_group is None:
torch.distributed.all_reduce(tensor=tensor, op=op)
else:
torch.distributed.all_reduce(tensor=tensor,
op=op,
group=self.model_parallel_group)
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
# if dist.get_rank() == 0:
# logger.info(f"Total Norm begining {total_norm}")
for g, p in zip(gradients, params):
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
# creates a flat fused tensor from the tensor list starting at the first_offset
# in the first tensor of the list. If there are not enough elements in the tensor
# list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
partition_size,
return_tensor_list=False):
flat_tensor_list = []
current_size = 0
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros_like(tensor)
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
# we don't need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
# we need a narrow view of the tensor based on the tensor offset and number of elements that
# we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)))
else:
flat_tensor_list.append(tensor)
current_size = current_size + num_elements
# this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
dtype=tensor_list[0].dtype,
device=tensor_list[0].device))
if return_tensor_list:
return flat_tensor_list
return _flatten_dense_tensors(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None
def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
def log_timers(self, timer_names):
self.timers.log(names=list(timer_names))
def start_timers(self, timer_names):
for name in timer_names:
self.timers(name).start()
def stop_timers(self, timer_names):
for name in timer_names:
self.timers(name).stop()
def old_step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = INITIAL_MICRO_STEP_ID
# if self.cpu_offload:
# torch.cuda.current_stream().wait_stream(self.migration_stream)
print_rank_0(f"Inside Step function")
see_memory_usage(f"In step before checking overflow", force=False)
print_rank_0("Finished Tracing at Beginning of Step")
self.param_coordinator.hierarchy = 0
self.param_coordinator.finish_tracing(print_trace=True)
self.param_coordinator.reset_step()
print_rank_0("Finished Tracing at Beginning of Step")
# First compute the norm for all groups so we know if there is overflow
self.check_overflow()
timers = self.timers
OPTIMIZER_STEP = 'optimizer_step'
OPTIMIZER_FP16_UPDATE = 'optimizer_fp16_update'
OPTIMIZER_FP32_GRADIENT = 'optimizer_fp32_gradient'
timer_names = [OPTIMIZER_STEP, OPTIMIZER_FP16_UPDATE, OPTIMIZER_FP32_GRADIENT]
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
see_memory_usage('After overflow before clearing gradients', force=False)
self.zero_grad()
if self.cpu_offload:
self.reset_cpu_buffers()
else:
self.averaged_gradients = {}
see_memory_usage('After overflow after clearing gradients', force=False)
logger.info(
"[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
self.start_timers(timer_names)
self.stop_timers(timer_names)
return
norm_groups = []
single_partition_grad_groups = []
skip = False
partition_id = dist.get_rank(group=self.dp_process_group)
debug_fp32_grads = [{} for _ in self.fp16_groups]
self.start_timers([OPTIMIZER_FP32_GRADIENT])
for i, group in enumerate(self.fp16_groups):
if self.cpu_offload:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.fp16_groups[i]))
single_grad_partition = self.fp32_partitioned_groups_flat[i].grad
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.fp16_groups[i]))
# free gradients for all the parameters that are not updated by this process
# self.free_grad_in_param_list(self.params_not_in_partition[i])
# create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
single_grad_partition = _flatten_dense_tensors(
self.averaged_gradients[i]).to(
self.fp32_partitioned_groups_flat[i].dtype)
assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[i].numel(), \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(
single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.fp32_partitioned_groups_flat[i].grad = single_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition
self.zero_grad()
self.averaged_gradients[i] = None
single_partition_grad_groups.append(single_grad_partition)
debug_fp32_grads[i] = [
(t.clone().detach(),
t) for t in _unflatten_dense_tensors(single_grad_partition,
group)
]
self.stop_timers([OPTIMIZER_FP32_GRADIENT])
print(f"Norm groups: {norm_groups}")
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
#self.dump_pre_step_gradients(debug_fp32_grads)
self.start_timers([OPTIMIZER_STEP])
self.optimizer.step()
self.stop_timers([OPTIMIZER_STEP])
# get rid of the fp32 gradients. Not needed anymore
if not self.cpu_offload:
for group in self.fp32_partitioned_groups_flat:
group.grad = None
self.start_timers([OPTIMIZER_FP16_UPDATE])
for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
fp16_partitions.data.copy_(fp32_partition.data)
self.stop_timers([OPTIMIZER_FP16_UPDATE])
print(
f"fp16 groups norm : {[group_flat.norm() for group_flat in self.fp16_partitioned_groups_flat]}"
)
if self.cpu_offload:
self.reset_cpu_buffers()
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
#for p in self.fp16_groups[i]:
# p.data=p.ds_tensor
updated_params = _unflatten_dense_tensors(
self.fp16_partitioned_groups_flat[i],
self.fp16_partitioned_groups[i])
for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params):
# print(f"Grad fn: {p.grad_fn}")
# p.data = torch.ones(1).half().cuda()
partitioned_param.data = q.data
#Gather persistent parameters
self.persistent_parameters[0].all_gather(self.persistent_parameters)
#self.dump_post_step_gradients()
self.debug_fp16_grads = [{} for _ in self.fp16_groups]
if self.cpu_offload:
self.reset_cpu_buffers()
self.log_timers(timer_names)
see_memory_usage('After zero_optimizer step', force=False)
print_rank_0(f"------------------Finishing Step-----------------------",
force=True)
return
def _pre_step(self):
self.micro_step_id = INITIAL_MICRO_STEP_ID
print_rank_0(f"Inside Step function")
see_memory_usage(f"In step before checking overflow", force=False)
print_rank_0("Finished Tracing at Beginning of Step")
self.param_coordinator.hierarchy = 0
self.param_coordinator.finish_tracing(print_trace=True)
self.param_coordinator.reset_step()
print_rank_0("Finished Tracing at Beginning of Step")
def _get_norm_groups(self):
norm_groups = []
for i, group in enumerate(self.fp16_groups):
if self.cpu_offload:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.fp16_groups[i]))
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.fp16_groups[i]))
return norm_groups
def _prepare_fp32_grad_for_sub_group(self, sub_group_id):
partition_id = dist.get_rank(group=self.dp_process_group)
single_grad_partition = _flatten_dense_tensors(
self.averaged_gradients[sub_group_id]).to(
self.fp32_partitioned_groups_flat[sub_group_id].dtype)
assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(
single_grad_partition.numel(), self.partition_size[sub_group_id], sub_group_id, partition_id)
self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition
self.zero_grad()
self.averaged_gradients[sub_group_id] = None
def _prepare_sub_group(self, sub_group_id, timer_names=set()):
see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}',
force=False)
if not self.cpu_offload:
self._prepare_fp32_grad_for_sub_group(sub_group_id)
see_memory_usage(f'After prepare optimizer sub group {sub_group_id}',
force=False)
def _release_sub_group(self, sub_group_id, timer_names=set()):
see_memory_usage(f'Before release optimizer sub group {sub_group_id}',
force=False)
# get rid of the fp32 gradients. Not needed anymore
if not self.cpu_offload:
self.fp32_partitioned_groups_flat[sub_group_id].grad = None
see_memory_usage(f'After release optimizer sub group {sub_group_id}',
force=False)
def _unflatten_partitioned_parameters(self, sub_group_id):
updated_params = _unflatten_dense_tensors(
self.fp16_partitioned_groups_flat[sub_group_id],
self.fp16_partitioned_groups[sub_group_id])
for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
partitioned_param.data = q.data
def _overflow_clean_up(self, prev_scale):
see_memory_usage('After overflow before clearing gradients', force=False)
self.zero_grad()
if self.cpu_offload:
self.reset_cpu_buffers()
else:
self.averaged_gradients = {}
see_memory_usage('After overflow after clearing gradients', force=False)
if torch.distributed.get_rank() == 0:
logger.info(
"[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
def _overflow_check_and_loss_scale_update(self):
# First compute the norm for all groups so we know if there is overflow
self.check_overflow()
#loss scaling related computation
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self._overflow_clean_up(prev_scale)
return self.overflow
def _post_step(self, timer_names=set()):
if self.cpu_offload:
self.reset_cpu_buffers()
#Gather persistent parameters
self.persistent_parameters[0].all_gather(self.persistent_parameters)
self.log_timers(timer_names)
see_memory_usage('After zero_optimizer step', force=True)
print_rank_0(f"------------------Finishing Step-----------------------")
def step(self, closure=None):
"""
Not supporting closure.
"""
self._pre_step()
#checks for overflow, adjust the loss scale accordingly
if self._overflow_check_and_loss_scale_update():
return
norm_groups = self._get_norm_groups()
timers = self.timers
timer_names = set()
timer_names.add('optimizer_step')
self.start_timers(['optimizer_step'])
#update parameters one sub group at a time
for sub_group_id, group in enumerate(self.fp16_groups):
#prepare optimizer states, gradients and fp32 parameters for update
self._prepare_sub_group(sub_group_id, timer_names)
#unscale and clip the fp32 gradients
self.unscale_and_clip_grads(sub_group_id, norm_groups)
#apply the optimizer step on the sub group and copy fp32 parameters to fp16
self._optimizer_step(sub_group_id)
#release memory or swap out optimizer states of fp32 parameters
self._release_sub_group(sub_group_id, timer_names)
#unflatten fp16 parameter subgroup
self._unflatten_partitioned_parameters(sub_group_id)
self.stop_timers(['optimizer_step'])
self._post_step(timer_names)
return
def dump_pre_step_gradients(self, debug_fp32_grads):
# Dump gradient norms for debugging
for i, _ in enumerate(self.fp16_groups):
print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC')
for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]):
param_id = self.get_param_id(fp16_param)
fp16_grad_norm = self.debug_fp16_grads[i][param_id]
fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad]
norm_list = [fp16_grad_norm, fp32_grad_norm]
print(f'Pre-Step Norms {i} {param_id} = {norm_list}')
def dump_post_step_gradients(self):
# Dump gradient norms for debugging
for i, group in enumerate(self.fp16_groups):
print(
f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT')
unflat_fp16 = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
unflat_fp32 = _unflatten_dense_tensors(self.fp32_partitioned_groups_flat[i],
self.fp16_groups[i])
for j, p in enumerate(self.fp16_groups[i]):
param_id = self.get_param_id(p)
param_norm = float(p.data.float().norm(2))
ds_norm = float(p.ds_tensor.data.float().norm(2))
unflat_norm = [
float(t.data.float().norm(2))
for t in [unflat_fp16[j],
unflat_fp32[j]]
]
norm_list = [param_norm, ds_norm] + unflat_norm
print(f'Post-Step Norms {i} {param_id} = {norm_list}')
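# Folds loss-scale removal and gradient clipping into one multiplier.
# Worked example (assumed numbers): loss_scale=1024, total_norm=2048,
# clip_grad=1.0 gives clip ~= 2, so combined_scale ~= 2048 and each gradient
# is multiplied by 1/2048, i.e. unscaled and clipped to roughly unit norm.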
def unscale_and_clip_grads(self, sub_group_id, norm_groups):
grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad]
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.fp16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
def has_overflow(self, partition_gradients=True):
if partition_gradients:
if self.overlap_comm:
self.local_overflow = self._has_inf_or_nan(self.gpu_sum)
self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial(
)
#overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
else:
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
overflow_gpu = torch.cuda.ByteTensor([overflow])
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu,
op=torch.distributed.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
print_rank_0(
f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}"
)
see_memory_usage(f"Before backward", force=False)
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(self.reduce_bucket_size,
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(self.reduce_bucket_size,
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
'''Partition parameters that were not partitioned.
Usually, parameters of modules whose inputs do not require grad computation
do not trigger the post hook and would therefore remain unpartitioned.'''
self._partition_all_parameters()
def _partition_all_parameters(self):
for name, param in self.module.named_parameters(recurse=True):
self.param_coordinator.release_and_reset_parameter(param)
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings):
# Remove paddings from flattened tensor
individual_tensors = _unflatten_dense_tensors(padded_flattened_tensor,
group_tensors)
lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)]
lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)]
#logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}')
return lean_tensors
#TODO REVISIT this for stage 3
def get_lean_optimizer_state(self):
# Return optimizer states after removing paddings.
# This method assumes that each param group contains a single flattened tensor.
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_state = {}
for key, value in self.optimizer.state[p].items():
if torch.is_tensor(value):
padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]]
lean_state[key] = self._get_lean_tensors(
value,
self.fp16_partitioned_groups[i],
self.groups_padding[i])
lean_flat_len = sum([t.numel() for t in lean_state[key]])
else:
lean_state[key] = value
optimizer_groups_state.append(lean_state)
return optimizer_groups_state
def get_groups_without_padding(self, groups_with_padding):
# Return group tensor after removing paddings added for alignment to DP world size.
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_group = self._get_lean_tensors(group,
self.fp16_partitioned_groups[i],
self.groups_padding[i])
groups_without_padding.append(lean_group)
return groups_without_padding
def _set_fp32_optimizer_param_groups(self):
for sub_group_id, _ in enumerate(self.fp16_groups):
param_group_id = self.sub_group_to_group_id[sub_group_id]
self.optimizer.param_groups[param_group_id]['params'] = [
self.fp32_partitioned_groups_flat[sub_group_id]
]
def _clear_fp32_optimizer_param_groups(self):
for sub_group_id, _ in enumerate(self.fp16_groups):
param_group_id = self.sub_group_to_group_id[sub_group_id]
self.optimizer.param_groups[param_group_id]['params'] = []
def _rigid_state_dict(self):
state_dict = {}
state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['partition_count'] = self.partition_count
self._set_fp32_optimizer_param_groups()
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat
self._clear_fp32_optimizer_param_groups()
return state_dict
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
if self.elastic_checkpoint:
raise NotImplementedError(
"ZeRO-3 does not yet support elastic checkpointing, please disable for now."
)
return self._rigid_state_dict()
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
flat_local_partition = []
for i in range(len(self.fp32_partitioned_groups_flat)):
merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict]
flat_local_partition.append(self._get_flattened_partition(merged_partitions))
for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
fp32_partition.data.copy_(fp16_partitions.data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
# Extract the flattened partition for the current rank from all partitions
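# For each per-parameter state, the saved slices from all ranks are regrouped
# per parameter, re-flattened with DP alignment, re-partitioned across the
# current data parallel world, and the slice owned by this rank is selected.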
def _get_flattened_partition(self, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
param_partitions = [[] for _ in range(len(all_partition_states[0]))]
for i, partition in enumerate(all_partition_states):
for j, param in enumerate(partition):
param_partitions[j].append(param)
local_state_partitions = []
for param_index, param_slices in enumerate(param_partitions):
flattened_merged_tensor = flatten_dense_tensors_aligned(
param_slices,
alignment)
new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor)
local_state_partitions.append(new_partitions[partition_id])
if torch.is_tensor(local_state_partitions[0]):
return flatten_dense_tensors_aligned(local_state_partitions, alignment)
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return local_state_partitions[0]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [
sd['base_optimizer_state'][i] for sd in all_state_dict
]
for key in all_partition_group_states[0].keys():
all_partition_states = [
all_states[key] for all_states in all_partition_group_states
]
partition_states[key] = self._get_flattened_partition(
all_partition_states)
base_optimizer_group_states.append(partition_states)
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
self.optimizer.state[p][key].data.copy_(saved.data)
else:
self.optimizer.state[p][key] = saved
def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self._set_fp32_optimizer_param_groups()
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
self._clear_fp32_optimizer_param_groups()
# restore fp32 partitions
for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']):
curr_param.data.copy_(saved_param.data)
# restore fp16 partitions from fp32
for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
fp16_param.data.copy_(fp32_param.data)
# update fp16 unflattened params
for sub_group_id in range(len(self.fp16_partitioned_groups_flat)):
updated_params = _unflatten_dense_tensors(
self.fp16_partitioned_groups_flat[sub_group_id],
self.fp16_partitioned_groups[sub_group_id])
for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
partitioned_param.data = q.data
# TODO: Support different/changing load/save DP degree.
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading a ZeRO checkpoint
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from the number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
"""
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
if self.elastic_checkpoint:
raise NotImplementedError(
"ZeRO-3 does not yet support elastic checkpointing, please disable for now."
)
else:
self._rigid_load_state_dict(
state_dict_list[dist.get_rank(group=self.dp_process_group)],
load_optimizer_states=load_optimizer_states)
self.persistent_parameters[0].partition(self.persistent_parameters)
self.persistent_parameters[0].all_gather(self.persistent_parameters)
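# Called before saving a checkpoint: re-partition (release) all gathered parameters
# so that each rank serializes only its own parameter partition.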
def save_checkpoint_prologue(self):
self._partition_all_parameters()
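# Called after the checkpoint has been written: re-gather the persistent
# (always-resident) parameters for continued training.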
def save_checkpoint_epilogue(self):
self.persistent_parameters[0].all_gather(self.persistent_parameters)
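# Debug helper: on rank 0, log the index of the first non-finite element found in
# tensor x (the i-th tensor inspected) together with its shape.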
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
logger.info(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
import torch
from deepspeed.runtime.zero.contiguous_memory_allocator import ContiguousMemoryAllocator
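# Exercises allocation/release in a 1024-element buffer: the releases fragment the
# free space, so the final 512-element allocation only fits if the allocator
# defragments; the assert checks that live tensors keep their values after any moves.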
def test1():
mem = ContiguousMemoryAllocator(1024, torch.half, 'cpu')
mem.print_allocation(resolution=100)
a1 = mem.allocate_tensor(64).mul_(0.0).add_(1.0)
mem.print_allocation(resolution=100)
mem.release_tensor(a1)
mem.print_allocation(resolution=100)
a2 = mem.allocate_tensor(64).mul_(0.0).add_(2.0)
a3 = mem.allocate_tensor(256).mul_(0.0).add_(3.0)
a4 = mem.allocate_tensor(128).mul_(0.0).add_(4.0)
mem.print_allocation(resolution=100)
mem.release_tensor(a3)
mem.print_allocation(resolution=100)
a5 = mem.allocate_tensor(64).mul_(0.0).add_(5.0)
a6 = mem.allocate_tensor(256).mul_(0.0).add_(6.0)
a7 = mem.allocate_tensor(128).mul_(0.0).add_(7.0)
mem.print_allocation(resolution=100)
a8 = mem.allocate_tensor(256).mul_(0.0).add_(8.0)
a9 = mem.allocate_tensor(128).mul_(0.0).add_(9.0)
mem.print_allocation(resolution=100)
mem.release_tensor(a9)
mem.release_tensor(a6)
mem.release_tensor(a2)
mem.release_tensor(a5)
a10 = mem.allocate_tensor(512).mul_(0.0).add_(10.0)
mem.print_allocation(resolution=100)
#print(f"a4:{a4}")
#print(f"a7:{a7}")
#print(f"a8:{a8}")
#print(f"a10:{a10}")
assert (a4.norm() + a7.norm() + a8.norm() + a10.norm()).item() == 474.50, "Test failed"
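# Same idea with a 512-element buffer: interleaved releases leave only scattered
# 64-element holes, so the 128-element allocations force the allocator to compact memory.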
def test2():
mem = ContiguousMemoryAllocator(512, torch.half, 'cpu')
a1 = mem.allocate_tensor(64).mul_(0.0).add_(1.0)
a2 = mem.allocate_tensor(64).mul_(0.0).add_(2.0)
a3 = mem.allocate_tensor(64).mul_(0.0).add_(3.0)
a4 = mem.allocate_tensor(64).mul_(0.0).add_(4.0)
a5 = mem.allocate_tensor(64).mul_(0.0).add_(5.0)
a6 = mem.allocate_tensor(64).mul_(0.0).add_(6.0)
a7 = mem.allocate_tensor(64).mul_(0.0).add_(7.0)
a8 = mem.allocate_tensor(64).mul_(0.0).add_(8.0)
mem.release_tensor(a2)
mem.release_tensor(a4)
mem.release_tensor(a6)
mem.release_tensor(a8)
mem.print_allocation(resolution=100)
a9 = mem.allocate_tensor(128).mul_(0.0).add_(9.0)
a10 = mem.allocate_tensor(64).mul_(0.0).add_(10.0)
a11 = mem.allocate_tensor(64).mul_(0.0).add_(11.0)
mem.release_tensor(a1)
mem.release_tensor(a5)
mem.print_allocation(resolution=100)
a12 = mem.allocate_tensor(128).mul_(0.0).add_(12.0)
mem.print_allocation(resolution=100)
print(f"a7:{a7}")
print(f"a9:{a9}")
print(f"a10:{a10}")
print(f"a11:{a11}")
print(f"a12:{a12}")
assert (a7.norm() + a9.norm() + a10.norm() + a11.norm() + a12.norm()).item() == 460.75, "Test failed"
test1()
test2()
......@@ -39,7 +39,8 @@ except ImportError:
def is_zero_supported_optimizer(optimizer):
print(
f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
)
if dist.get_rank() == 0:
print(
f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
)
return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
......@@ -10,7 +10,7 @@ RUN mkdir -p ${STAGE_DIR}
# Installation/Basic Utilities
##############################################################################
RUN apt-get update && \
apt-get install -y --no-install-recommends \
apt-get install -y --no-install-recommends \
software-properties-common build-essential autotools-dev \
nfs-common pdsh \
cmake g++ gcc \
......@@ -23,9 +23,9 @@ RUN apt-get update && \
# Installation Latest Git
##############################################################################
RUN add-apt-repository ppa:git-core/ppa -y && \
apt-get update && \
apt-get install -y git && \
git --version
apt-get update && \
apt-get install -y git && \
git --version
##############################################################################
# Client Liveness & Uncomment Port 22 for SSH Daemon
......@@ -33,7 +33,7 @@ RUN add-apt-repository ppa:git-core/ppa -y && \
# Keep SSH client alive from server side
RUN echo "ClientAliveInterval 30" >> /etc/ssh/sshd_config
RUN cp /etc/ssh/sshd_config ${STAGE_DIR}/sshd_config && \
sed "0,/^#Port 22/s//Port 22/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config
sed "0,/^#Port 22/s//Port 22/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config
##############################################################################
# Mellanox OFED
......@@ -41,11 +41,11 @@ RUN cp /etc/ssh/sshd_config ${STAGE_DIR}/sshd_config && \
ENV MLNX_OFED_VERSION=4.6-1.0.1.1
RUN apt-get install -y libnuma-dev
RUN cd ${STAGE_DIR} && \
wget -q -O - http://www.mellanox.com/downloads/ofed/MLNX_OFED-${MLNX_OFED_VERSION}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64.tgz | tar xzf - && \
cd MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64 && \
./mlnxofedinstall --user-space-only --without-fw-update --all -q && \
cd ${STAGE_DIR} && \
rm -rf ${STAGE_DIR}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64*
wget -q -O - http://www.mellanox.com/downloads/ofed/MLNX_OFED-${MLNX_OFED_VERSION}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64.tgz | tar xzf - && \
cd MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64 && \
./mlnxofedinstall --user-space-only --without-fw-update --all -q && \
cd ${STAGE_DIR} && \
rm -rf ${STAGE_DIR}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64*
##############################################################################
# nv_peer_mem
......@@ -53,16 +53,16 @@ RUN cd ${STAGE_DIR} && \
ENV NV_PEER_MEM_VERSION=1.1
ENV NV_PEER_MEM_TAG=1.1-0
RUN mkdir -p ${STAGE_DIR} && \
git clone https://github.com/Mellanox/nv_peer_memory.git --branch ${NV_PEER_MEM_TAG} ${STAGE_DIR}/nv_peer_memory && \
cd ${STAGE_DIR}/nv_peer_memory && \
./build_module.sh && \
cd ${STAGE_DIR} && \
tar xzf ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_VERSION}.orig.tar.gz && \
cd ${STAGE_DIR}/nvidia-peer-memory-${NV_PEER_MEM_VERSION} && \
apt-get update && \
apt-get install -y dkms && \
dpkg-buildpackage -us -uc && \
dpkg -i ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_TAG}_all.deb
git clone https://github.com/Mellanox/nv_peer_memory.git --branch ${NV_PEER_MEM_TAG} ${STAGE_DIR}/nv_peer_memory && \
cd ${STAGE_DIR}/nv_peer_memory && \
./build_module.sh && \
cd ${STAGE_DIR} && \
tar xzf ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_VERSION}.orig.tar.gz && \
cd ${STAGE_DIR}/nvidia-peer-memory-${NV_PEER_MEM_VERSION} && \
apt-get update && \
apt-get install -y dkms && \
dpkg-buildpackage -us -uc && \
dpkg -i ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_TAG}_all.deb
##############################################################################
# OPENMPI
......@@ -70,22 +70,22 @@ RUN mkdir -p ${STAGE_DIR} && \
ENV OPENMPI_BASEVERSION=4.0
ENV OPENMPI_VERSION=${OPENMPI_BASEVERSION}.1
RUN cd ${STAGE_DIR} && \
wget -q -O - https://download.open-mpi.org/release/open-mpi/v${OPENMPI_BASEVERSION}/openmpi-${OPENMPI_VERSION}.tar.gz | tar xzf - && \
cd openmpi-${OPENMPI_VERSION} && \
./configure --prefix=/usr/local/openmpi-${OPENMPI_VERSION} && \
make -j"$(nproc)" install && \
ln -s /usr/local/openmpi-${OPENMPI_VERSION} /usr/local/mpi && \
# Sanity check:
test -f /usr/local/mpi/bin/mpic++ && \
cd ${STAGE_DIR} && \
rm -r ${STAGE_DIR}/openmpi-${OPENMPI_VERSION}
wget -q -O - https://download.open-mpi.org/release/open-mpi/v${OPENMPI_BASEVERSION}/openmpi-${OPENMPI_VERSION}.tar.gz | tar xzf - && \
cd openmpi-${OPENMPI_VERSION} && \
./configure --prefix=/usr/local/openmpi-${OPENMPI_VERSION} && \
make -j"$(nproc)" install && \
ln -s /usr/local/openmpi-${OPENMPI_VERSION} /usr/local/mpi && \
# Sanity check:
test -f /usr/local/mpi/bin/mpic++ && \
cd ${STAGE_DIR} && \
rm -r ${STAGE_DIR}/openmpi-${OPENMPI_VERSION}
ENV PATH=/usr/local/mpi/bin:${PATH} \
LD_LIBRARY_PATH=/usr/local/lib:/usr/local/mpi/lib:/usr/local/mpi/lib64:${LD_LIBRARY_PATH}
LD_LIBRARY_PATH=/usr/local/lib:/usr/local/mpi/lib:/usr/local/mpi/lib64:${LD_LIBRARY_PATH}
# Create a wrapper for OpenMPI to allow running as root by default
RUN mv /usr/local/mpi/bin/mpirun /usr/local/mpi/bin/mpirun.real && \
echo '#!/bin/bash' > /usr/local/mpi/bin/mpirun && \
echo 'mpirun.real --allow-run-as-root --prefix /usr/local/mpi "$@"' >> /usr/local/mpi/bin/mpirun && \
chmod a+x /usr/local/mpi/bin/mpirun
echo '#!/bin/bash' > /usr/local/mpi/bin/mpirun && \
echo 'mpirun.real --allow-run-as-root --prefix /usr/local/mpi "$@"' >> /usr/local/mpi/bin/mpirun && \
chmod a+x /usr/local/mpi/bin/mpirun
##############################################################################
# Python
......@@ -93,14 +93,14 @@ RUN mv /usr/local/mpi/bin/mpirun /usr/local/mpi/bin/mpirun.real && \
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHON_VERSION=3
RUN apt-get install -y python3 python3-dev && \
rm -f /usr/bin/python && \
ln -s /usr/bin/python3 /usr/bin/python && \
curl -O https://bootstrap.pypa.io/get-pip.py && \
rm -f /usr/bin/python && \
ln -s /usr/bin/python3 /usr/bin/python && \
curl -O https://bootstrap.pypa.io/get-pip.py && \
python get-pip.py && \
rm get-pip.py && \
pip install --upgrade pip && \
# Print python and pip version
python -V && pip -V
pip install --upgrade pip && \
# Print python and pip version
python -V && pip -V
RUN pip install pyyaml
RUN pip install ipython
......@@ -114,44 +114,45 @@ RUN pip install tensorflow-gpu==${TENSORFLOW_VERSION}
# Some Packages
##############################################################################
RUN apt-get update && \
apt-get install -y --no-install-recommends \
apt-get install -y --no-install-recommends \
libsndfile-dev \
libcupti-dev \
libjpeg-dev \
libpng-dev \
screen
screen \
libaio-dev
RUN pip install psutil \
yappi \
cffi \
ipdb \
pandas \
matplotlib \
py3nvml \
pyarrow \
graphviz \
astor \
boto3 \
tqdm \
sentencepiece \
msgpack \
requests \
pandas \
sphinx \
sphinx_rtd_theme \
scipy \
numpy \
sklearn \
scikit-learn \
nvidia-ml-py3 \
mpi4py \
cupy-cuda100
yappi \
cffi \
ipdb \
pandas \
matplotlib \
py3nvml \
pyarrow \
graphviz \
astor \
boto3 \
tqdm \
sentencepiece \
msgpack \
requests \
pandas \
sphinx \
sphinx_rtd_theme \
scipy \
numpy \
sklearn \
scikit-learn \
nvidia-ml-py3 \
mpi4py \
cupy-cuda100
##############################################################################
## SSH daemon port inside container cannot conflict with host OS port
###############################################################################
ENV SSH_PORT=2222
RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \
sed "0,/^#Port 22/s//Port ${SSH_PORT}/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config
sed "0,/^#Port 22/s//Port ${SSH_PORT}/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config
##############################################################################
# PyTorch
......@@ -168,7 +169,7 @@ RUN pip install tensorboardX==${TENSORBOARDX_VERSION}
# https://stackoverflow.com/a/53926898
##############################################################################
RUN rm -rf /usr/lib/python3/dist-packages/yaml && \
rm -rf /usr/lib/python3/dist-packages/PyYAML-*
rm -rf /usr/lib/python3/dist-packages/PyYAML-*
##############################################################################
## Add deepspeed user
......@@ -186,8 +187,8 @@ USER deepspeed
##############################################################################
RUN git clone https://github.com/microsoft/DeepSpeed.git ${STAGE_DIR}/DeepSpeed
RUN cd ${STAGE_DIR}/DeepSpeed && \
git checkout . && \
git checkout master && \
./install.sh --pip_sudo
git checkout . && \
git checkout master && \
./install.sh --pip_sudo
RUN rm -rf ${STAGE_DIR}/DeepSpeed
RUN python -c "import deepspeed; print(deepspeed.__version__)"
......@@ -232,14 +232,22 @@ Example of ***scheduler***
Enabling and configuring ZeRO memory optimizations
```json
"zero_optimization": {
"stage": [0|1|2],
"stage": [0|1|2|3],
"allgather_partitions": [true|false],
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": [true|false],
"reduce_bucket_size": 5e8,
"contiguous_gradients" : [true|false],
"cpu_offload": [true|false]
"cpu_offload": [true|false],
"cpu_offload_params" : [true|false],
"cpu_offload_use_pin_memory" : [true|false],
"stage3_max_live_parameters" : 1e9,
"stage3_max_reuse_distance" : 1e9,
"stage3_prefetch_bucket_size" : 5e8,
"stage3_param_persistence_threshold" : 1e6,
"sub_group_size" : 1e12,
"elastic_checkpoint" : [true|false]
}
```
......@@ -253,7 +261,7 @@ Enabling and configuring ZeRO memory optimizations
| Description | Default |
| --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Chooses different stages of ZeRO Optimizer. Stage 0, 1, and 2 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitiong, respectively. | `0` |
| Chooses different stages of ZeRO Optimizer. Stage 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively. | `0` |
***allgather_partitions***: [boolean]
......@@ -297,6 +305,42 @@ Enabling and configuring ZeRO memory optimizations
| ------------------------------------------------------------------------------------------------------------------------ | ------- |
| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. | `False` |
***cpu_offload_params***: [boolean]
| Description | Default |
| --------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Enable offloading of model parameters to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3. | `False` |
***cpu_offload_use_pin_memory***: [boolean]
| Description | Default |
| ----------------------------------------------------------------------------------------- | ------- |
| Use pinned CPU memory when offloading. Can improve performance. Valid only with stage 3. | `False` |
***stage3_max_live_parameters***: [integer]
| Description | Default |
| ------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication. | `1e9` |
***stage3_max_reuse_distance***: [integer]
| Description | Default |
| ---------------------------------------------------------------------------------------------------------------- | ------- |
| Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication. | `1e9` |
***stage3_prefetch_bucket_size***: [integer]
| Description | Default |
| ------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The size of the fixed buffer for prefetching parameters. Smaller values use less memory, but can increase stalls due to communication. | `5e8` |
***stage3_param_persistence_threshold***: [integer]
| Description | Default |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` |
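Putting these knobs together, a ZeRO-3 Offload configuration fragment might look like the following. The stage 3 tuning values below are simply the documented defaults, and the offload flags are switched on only to illustrate ZeRO-3 Offload; treat them as a starting point rather than tuned recommendations.

```json
"zero_optimization": {
    "stage": 3,
    "cpu_offload": true,
    "cpu_offload_params": true,
    "cpu_offload_use_pin_memory": true,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6,
    "sub_group_size": 1e12
}
```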
### Logging
......
......@@ -3,7 +3,7 @@ title: "Zero Redundancy Optimizer (ZeRO)"
---
If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/tutorials/megatron/) before stepping through this tutorial.
In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with billions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration json*. No code changes are needed.
In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need to do is change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed.
## ZeRO Overview
ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3).
......@@ -12,11 +12,13 @@ ZeRO leverages the aggregate computation and memory resources of data parallelis
* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.
* **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO will automatically collect and partition them during the forward and backward passes.
## Training environment
We use the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/tutorials/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM.
## Enabling ZeRO Optimization
To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed json configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training).
To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed JSON configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training).
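For example, turning on ZeRO stage 1 with every other knob left at its default is as small a change as the following; the rest of the training configuration (batch size, fp16, optimizer, and so on) stays untouched:

```json
{
  "zero_optimization": {
    "stage": 1
  }
}
```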
### Training a 1.5B Parameter GPT-2 model
We demonstrate the benefits of ZeRO stage 1 by showing that it enables data parallel training of a 1.5 billion parameter GPT-2 model on eight V100 GPUs. We configure training to use a batch size of 1 per device to ensure that the memory consumption is primarily due to model parameters and optimizer states. We create this training scenario by applying the following modifications to the deepspeed launch script:
......@@ -36,7 +38,7 @@ Training this model without ZeRO fails with an out-of-memory (OOM) error as show
<img src="/assets/images/oom_dp8_1.5B_log.png">
</a>
A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB; a significant portion of the 32GB RAM. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed json config file as below:
A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB, a significant portion of the 32GB RAM: with mixed-precision Adam, each parameter carries roughly 12 bytes of optimizer state (the fp32 master weight, momentum, and variance), so 1.5 billion parameters require about 18GB. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed JSON config file as below:
```json
{
......@@ -75,7 +77,7 @@ First, we need to configure a 10B parameter model with activation checkpointing
--checkpoint-activations
```
Next, we need to update the DeepSpeed json configuration, as shown below, to enable ZeRO stage 2 optimizations:
Next, we need to update the DeepSpeed JSON configuration, as shown below, to enable ZeRO stage 2 optimizations:
```json
{
......@@ -104,4 +106,159 @@ Here is a screenshot of nvidia-smi showing GPU activity during training:
<img src="/assets/images/zero2_dp32_10B_smi.png">
</a>
### Training trillion-scale models with ZeRO-3 Offload
Stage 3 can be enabled in the JSON configuration. A full description of these
configurations is available [here](/docs/config-json/#zero-optimizations-for-fp16-training).
```json
{
"zero_optimization": {
"stage": 3,
"cpu_offload": true,
"cpu_offload_params": true,
"overlap_comm": true,
"contiguous_gradients": true,
"stage3_max_live_parameters": 6000000,
"stage3_max_reuse_distance": 100000000,
"stage3_prefetch_bucket_size": 200000,
"stage3_param_persitance_threshold": 100000,
"reduce_bucket_size": 3000000,
"sub_group_size": 1e6
}
}
```
ZeRO-3 will automatically collect and partition the parameters as they are
needed during the forward and backward passes. However, in some cases a
parameter may be used outside of its module's forward pass. We call these
*external parameters*. ZeRO-3 can coordinate these parameters if they are
registered. Please see our [ZeRO-3 docs](https://deepspeed.readthedocs.io/en/latest/zero3.html) for more
information and examples of external parameters.
The Megatron-LM model has three external parameters that must be registered
with ZeRO-3. External parameters are those that are accessed outside of the
owning module's forward pass.
1. `megatron/model/gpt2_model.py:GPT2Model`: register the word embedding weight, which is used both inside the embedding's forward pass and again to compute the output logits.
```python
class GPT2Model(MegatronModule):
def __init__(self, num_tokentypes=0, parallel_output=True):
...
deepspeed.zero.register_external_parameter(self,
self.language_model.embedding.word_embeddings.weight)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
# self.embeddings will compute its forward pass here
lm_output = self.language_model(input_ids,
position_ids,
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
...
# Accesses word_embeddings.weight outside of the embedding's forward pass.
output = parallel_lm_logits(
lm_output,
self.language_model.embedding.word_embeddings.weight,
parallel_output)
```
2. `megatron/model/transformer.py:ParallelMLP`: register a bias that is
returned from a submodule's forward pass and used in this module's forward pass.
```python
class ParallelMLP(MegatronModule):
def __init__(self, init_method, output_layer_init_method):
...
if self.dense_h_to_4h.bias is not None:
deepspeed.zero.register_external_parameter(self, self.dense_h_to_4h.bias)
def forward(self, hidden_states):
# bias_parallel is a parameter of dense_h_to_4h
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
...
```
3. `megatron/model/transformer.py:ParallelTransformerLayer`: register two biases that
are returned from submodules and used in this module's forward pass.
```python
class ParallelTransformerLayer(MegatronModule):
...
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
...
if self.attention.dense.bias is not None:
deepspeed.zero.register_external_parameter(self, self.attention.dense.bias)
if self.mlp.dense_4h_to_h.bias is not None:
deepspeed.zero.register_external_parameter(self, self.mlp.dense_4h_to_h.bias)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
...
# attention_bias is a parameter returned from attention
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
...
# mlp_bias is a parameter returned from mlp
mlp_output, mlp_bias = self.mlp(layernorm_output)
...
```
#### Allocating Massive Megatron-LM Models
We make two further changes to model initialization in order to support models
that exceed *local* system memory, but not *total* system memory.
1. Allocate the model in a memory-scalable fashion. The model parameters will
be allocated and immediately partitioned across the data parallel group. If
`remote_device="cpu"`, the model will also be allocated in CPU memory
instead of GPU memory. Please see the full
[ZeRO-3 Init docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.Init)
for more details.
```python
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=get_args().remote_device,
enabled=get_args().zero_stage==3):
model = GPT2Model(num_tokentypes=0, parallel_output=True)
```
2. Gather the position embeddings weight for initialization. DeepSpeed will automatically
gather a module's parameters during its constructor and for its forward and backward pass.
However, additional accesses must coordinate with DeepSpeed to ensure that parameter data
is gathered and subsequently partitioned. If the tensor is modified, the `modifier_rank`
argument should also be used to ensure all ranks have a consistent view of
the data. Please see the full
[GatheredParameters docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.GatheredParameters)
for more details.
```python
self.position_embeddings = torch.nn.Embedding(...)
with deepspeed.zero.GatheredParameters(self.position_embeddings.weight,
modifier_rank=0):
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
```
Congratulations! You have completed the ZeRO tutorial.