Unverified Commit 599258f9 authored by Samyam Rajbhandari, committed by GitHub

ZeRO 3 Offload (#834)



* Squash stage3 v1 (#146)
Co-authored-by: Samyam <samyamr@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>

* Fix correctness bug (#147)

* formatting fix (#150)

* stage3 bugfix (API) update and simplified FP16 Z3 tests (#151)

* fp16 Z3 API update and bugfix

* revert debug change

* ZeRO-3 detach and race condition bugfixes (#149)

* trying out ZeRO-3 race condition fix

* CUDA sync instead of stream

* reduction stream sync

* remove commented code

* Fix optimizer state_dict KeyError (#148)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* fix for smaller SGS sizes, ensures each grad is backed by unique tensors (#152)

* Simplifying the logic for getting averaged gradients (#153)

* skip for now

* Z3 Docs redux (#154)

* removing some TODOs and commented code (#155)

* New Z3 defaults (#156)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* formatting

* megatron external params
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>
parent ba33e86e
...@@ -48,4 +48,4 @@ jobs: ...@@ -48,4 +48,4 @@ jobs:
- name: Unit tests - name: Unit tests
run: | run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose tests/unit/
...@@ -16,6 +16,8 @@ from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConf ...@@ -16,6 +16,8 @@ from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConf
from .utils import log_dist from .utils import log_dist
from .utils.distributed import init_distributed from .utils.distributed import init_distributed
from .runtime import zero
from .pipe import PipelineModule from .pipe import PipelineModule
from .git_version_info import version, git_hash, git_branch from .git_version_info import version, git_hash, git_branch
......
...@@ -304,7 +304,7 @@ def main(args=None): ...@@ -304,7 +304,7 @@ def main(args=None):
# encode world info as base64 to make it easier to pass via command line # encode world info as base64 to make it easier to pass via command line
world_info_base64 = encode_world_info(active_resources) world_info_base64 = encode_world_info(active_resources)
multi_node_exec = len(active_resources) > 1 multi_node_exec = True # len(active_resources) > 1
if multi_node_exec and not shutil.which('pdsh'): if multi_node_exec and not shutil.which('pdsh'):
raise RuntimeError("pdsh is not installed, unable to proceed") raise RuntimeError("pdsh is not installed, unable to proceed")
......
...@@ -10,12 +10,24 @@ from ..op_builder import CPUAdamBuilder ...@@ -10,12 +10,24 @@ from ..op_builder import CPUAdamBuilder
class DeepSpeedCPUAdam(torch.optim.Optimizer): class DeepSpeedCPUAdam(torch.optim.Optimizer):
optimizer_id = 0
def __init__(self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adamw_mode=True):
"""Fast vectorized implementation of two variations of Adam optimizer on CPU: """Fast vectorized implementation of two variations of Adam optimizer on CPU:
- Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980); * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
- AdamW: FIXING WEIGHT DECAY REGULARIZATION IN ADAM (https://arxiv.org/abs/1711.05101v1) * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101)
DeepSpeed CPU Adam(W) provides between 5x to 7x speedu over torch.optim.adam(W). DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W).
In order to apply this optimizer, the model requires to have its master parameter (in FP32) In order to apply this optimizer, the model requires to have its master parameter (in FP32)
reside on the CPU memory. reside on the CPU memory.
...@@ -30,6 +42,13 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer): ...@@ -30,6 +42,13 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
option can bring 30% higher throughput than the doing the copy separately using option one. option can bring 30% higher throughput than the doing the copy separately using option one.
.. note::
We recommend using our `config
<https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
to allow :meth:`deepspeed.initialize` to build this optimizer
for you.
Arguments: Arguments:
model_params (iterable): iterable of parameters to optimize or dicts defining model_params (iterable): iterable of parameters to optimize or dicts defining
parameter groups. parameter groups.
...@@ -45,19 +64,6 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer): ...@@ -45,19 +64,6 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
adamw_mode: select between Adam and AdamW implementations (default: AdamW) adamw_mode: select between Adam and AdamW implementations (default: AdamW)
""" """
optimizer_id = 0
def __init__(self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adamw_mode=True):
default_args = dict(lr=lr, default_args = dict(lr=lr,
betas=betas, betas=betas,
eps=eps, eps=eps,
...@@ -86,6 +92,24 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer): ...@@ -86,6 +92,24 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
@torch.no_grad() @torch.no_grad()
def step(self, closure=None, fp16_param_groups=None): def step(self, closure=None, fp16_param_groups=None):
"""Update the model parameters.
.. note::
This method will be called internally by ZeRO-Offload. DeepSpeed
users should still use ``engine.step()`` as shown in the
`Getting Started
<https://www.deepspeed.ai/getting-started/#training>`_ guide.
Args:
closure (callable, optional): closure to compute the loss.
Defaults to ``None``.
fp16_param_groups: FP16 GPU parameters to update. Performing the
copy here reduces communication time. Defaults to ``None``.
Returns:
loss: if ``closure`` is provided. Otherwise ``None``.
"""
loss = None loss = None
if closure is not None: if closure is not None:
with torch.enable_grad(): with torch.enable_grad():
...@@ -100,7 +124,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer): ...@@ -100,7 +124,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
state = self.state[p] state = self.state[p]
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
print(f'group {group_id} param {param_id} = {p.numel()}') #print(f'group {group_id} param {param_id} = {p.numel()}')
state['step'] = 0 state['step'] = 0
# gradient momentums # gradient momentums
state['exp_avg'] = torch.zeros_like(p.data, state['exp_avg'] = torch.zeros_like(p.data,
......
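A hedged usage sketch of the optimizer diffed above: direct construction is shown for illustration, while the docstring note recommends letting deepspeed.initialize build it from the JSON config. The import path and the availability of a compiled cpu_adam op are assumptions, not part of this diff.

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam   # assumed import path for this op

# Master (FP32) parameters must reside in CPU memory, as the docstring requires.
model = torch.nn.Linear(1024, 1024)
optimizer = DeepSpeedCPUAdam(model.parameters(),
                             lr=1e-3,
                             weight_decay=0.01,
                             adamw_mode=True)      # AdamW-style decoupled weight decay

loss = model(torch.randn(8, 1024)).sum()
loss.backward()
optimizer.step()   # DeepSpeed engine users call engine.step() instead (see step() docstring)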
...@@ -18,6 +18,7 @@ import torch ...@@ -18,6 +18,7 @@ import torch
import contextlib import contextlib
import torch.distributed as dist import torch.distributed as dist
import mmap
from torch import _C from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager from torch.cuda import _lazy_call, device as device_ctx_manager
...@@ -26,19 +27,19 @@ from deepspeed.utils import logger ...@@ -26,19 +27,19 @@ from deepspeed.utils import logger
from deepspeed.runtime.utils import move_to_device from deepspeed.runtime.utils import move_to_device
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers
#DeepSpeed Checkpointing Enabled or Disabled # DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False deepspeed_checkpointing_enabled = False
#MP parameters # MP parameters
mpu = None mpu = None
mp_rank = None mp_rank = None
mp_size = None mp_size = None
mp_group = None mp_group = None
#Model Parameters # Model Parameters
num_layers = None num_layers = None
#Checkpointing buffers # Checkpointing buffers
contiguous_data_buffers = [] contiguous_data_buffers = []
data_offsets = [] data_offsets = []
...@@ -47,7 +48,7 @@ size_offsets = [] ...@@ -47,7 +48,7 @@ size_offsets = []
timers = None timers = None
#optimization flags # optimization flags
PARTITION_ACTIVATIONS = False PARTITION_ACTIVATIONS = False
PA_TO_CPU = False PA_TO_CPU = False
CONTIGUOUS_CHECKPOINTING = False CONTIGUOUS_CHECKPOINTING = False
...@@ -56,10 +57,10 @@ PROFILE_TIME = False ...@@ -56,10 +57,10 @@ PROFILE_TIME = False
def see_memory_usage(message, force=False): def see_memory_usage(message, force=False):
#return # return
if not force: if not force:
return return
#dist.barrier() # dist.barrier()
if dist.get_rank() == 0: if dist.get_rank() == 0:
logger.info(message) logger.info(message)
logger.info( logger.info(
...@@ -78,6 +79,7 @@ def see_memory_usage(message, force=False): ...@@ -78,6 +79,7 @@ def see_memory_usage(message, force=False):
"Max cache Allocated %s GigaBytes", "Max cache Allocated %s GigaBytes",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
) )
logger.info("")
#input("Press Any Key To Continue ..") #input("Press Any Key To Continue ..")
...@@ -348,7 +350,22 @@ def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags): ...@@ -348,7 +350,22 @@ def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
tensor_idx = 0 tensor_idx = 0
non_tensor_idx = 0 non_tensor_idx = 0
for is_tensor in tensor_flags: real_tensor_flags = None
#remove the flags that are assigned to the size of the flattened tensors
if PARTITION_ACTIVATIONS:
real_tensor_flags = []
previous_flag = False
for flag in tensor_flags:
if previous_flag:
previous_flag = False
continue
previous_flag = flag
real_tensor_flags.append(flag)
else:
real_tensor_flags = tensor_flags
for is_tensor in real_tensor_flags:
if is_tensor: if is_tensor:
merged_objects.append(tensor_objects[tensor_idx]) merged_objects.append(tensor_objects[tensor_idx])
tensor_idx += 1 tensor_idx += 1
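A toy trace of the flag filtering added above, assuming partition_activations stores each tensor followed by its flattened-size entry; the input list here is hypothetical.

tensor_flags = [True, True, False, True, True]   # tensor, its size entry, non-tensor, tensor, its size entry
real_tensor_flags = []
previous_flag = False
for flag in tensor_flags:
    if previous_flag:            # skip the flag that belongs to the size tensor
        previous_flag = False
        continue
    previous_flag = flag
    real_tensor_flags.append(flag)
print(real_tensor_flags)         # [True, False, True]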
...@@ -406,7 +423,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -406,7 +423,7 @@ class CheckpointFunction(torch.autograd.Function):
global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
if cuda_device is None: if cuda_device is None:
see_memory_usage("First Forward Begining", force=True) see_memory_usage("First Forward Begining", force=False)
if dist.get_rank() == 0: if dist.get_rank() == 0:
logger.info(f"Activation Checkpointing Information") logger.info(f"Activation Checkpointing Information")
logger.info( logger.info(
...@@ -423,7 +440,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -423,7 +440,7 @@ class CheckpointFunction(torch.autograd.Function):
if PARTITION_ACTIVATIONS: if PARTITION_ACTIVATIONS:
#inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]] #inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
#inputs.append(args[-1]) # inputs.append(args[-1])
inputs = [] inputs = []
for i, item in enumerate(args[:-1]): for i, item in enumerate(args[:-1]):
...@@ -460,6 +477,19 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -460,6 +477,19 @@ class CheckpointFunction(torch.autograd.Function):
contiguous_data_buffers[i] = tensor_list contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0 data_offsets[i] = 0
# Because the 'new_empty' returns uninitialized pages,
# the pages need to be populated during the cudaMemcpy time
# which increases the data copy time. To avoid this, we
# pre-populate these pages by simply writing 0 ahead of
# the actual cudaMemcpy operation time. Due to the
# previously launched GPU kernels, there is a small
# window of time here for CPUs to populate pages asynchronously.
contiguous_data_buffers[i][data_offsets[i]].data[range(
0,
contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
int(mmap.PAGESIZE / contiguous_data_buffers[i][
data_offsets[i]].data.element_size()))] = 0
contiguous_partition = contiguous_data_buffers[i][ contiguous_partition = contiguous_data_buffers[i][
data_offsets[i]].data.copy_(partition.data) data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1 data_offsets[i] = data_offsets[i] + 1
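The comment above describes pre-touching one element per OS page so the subsequent copy_() does not pay the page-fault cost; a standalone sketch of the same trick, with an illustrative buffer size and dtype:

import mmap
import torch

buffer = torch.empty(1 << 22, dtype=torch.half)    # freshly allocated, uninitialized CPU pages
stride = mmap.PAGESIZE // buffer.element_size()    # number of elements per OS page
buffer[range(0, buffer.numel(), stride)] = 0       # write one element per page to fault pages in
# A later buffer.copy_(gpu_tensor.view(-1)) then runs without page-fault stalls.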
...@@ -478,14 +508,16 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -478,14 +508,16 @@ class CheckpointFunction(torch.autograd.Function):
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
#ctx.save_for_backward(*args) see_memory_usage("Before running forward on the layer", force=False)
# ctx.save_for_backward(*args)
with torch.no_grad(): with torch.no_grad():
outputs = run_function(*inputs_cuda) outputs = run_function(*inputs_cuda)
see_memory_usage("After running forward on the layer", force=False)
del inputs_cuda del inputs_cuda
#with torch.cuda.stream(transport_stream): # with torch.cuda.stream(transport_stream):
#if PARTITION_ACTIVATIONS: # if PARTITION_ACTIVATIONS:
# new_args = [] # new_args = []
# for arg, inp in zip(args,inputs): # for arg, inp in zip(args,inputs):
# size= torch.tensor(arg.size()) # size= torch.tensor(arg.size())
...@@ -531,7 +563,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -531,7 +563,7 @@ class CheckpointFunction(torch.autograd.Function):
new_args.append(contiguous_size) new_args.append(contiguous_size)
else: else:
new_args.append(size) new_args.append(size)
#if dist.get_rank() == 0: # if dist.get_rank() == 0:
# logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ") # logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")
save_args_for_backward(*new_args) save_args_for_backward(*new_args)
...@@ -564,10 +596,10 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -564,10 +596,10 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *grads): def backward(ctx, *grads):
global timers global timers
#see_memory_usage("In backward", force=True) see_memory_usage("In backward", force=False)
#removing pointers to the contiguous buffer memory # removing pointers to the contiguous buffer memory
#so that they can be garbage collected once the checkpoints # so that they can be garbage collected once the checkpoints
#have been used # have been used
if SYNCHRONIZE: if SYNCHRONIZE:
torch.cuda.synchronize() torch.cuda.synchronize()
if PROFILE_TIME: if PROFILE_TIME:
...@@ -580,14 +612,14 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -580,14 +612,14 @@ class CheckpointFunction(torch.autograd.Function):
for buffers in contiguous_data_buffers: for buffers in contiguous_data_buffers:
buffers = [] buffers = []
#frees up all the pointers to the checkpoints except for the ones # frees up all the pointers to the checkpoints except for the ones
#stored by save for backward # stored by save for backward
contiguous_data_buffers = [] contiguous_data_buffers = []
contiguous_size_buffers = [] contiguous_size_buffers = []
data_offsets = [] data_offsets = []
size_offsets = [] size_offsets = []
#see_memory_usage("In backward checkpointing code", force=True) see_memory_usage("In backward checkpointing code", force=False)
if not torch.autograd._is_checkpoint_valid(): if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), " raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible") "please use .backward() if possible")
...@@ -595,7 +627,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -595,7 +627,7 @@ class CheckpointFunction(torch.autograd.Function):
global cuda_device, transport_stream, PARTITION_ACTIVATIONS global cuda_device, transport_stream, PARTITION_ACTIVATIONS
if PARTITION_ACTIVATIONS: if PARTITION_ACTIVATIONS:
#with torch.cuda.stream(transport_stream): # with torch.cuda.stream(transport_stream):
inputs = get_full_inputs(ctx.saved_tensors, inputs = get_full_inputs(ctx.saved_tensors,
device=cuda_device if PA_TO_CPU else None) device=cuda_device if PA_TO_CPU else None)
detached_inputs = detach_variable(inputs) detached_inputs = detach_variable(inputs)
...@@ -622,9 +654,12 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -622,9 +654,12 @@ class CheckpointFunction(torch.autograd.Function):
# current_stream=torch.cuda.current_stream() # current_stream=torch.cuda.current_stream()
# current_stream.wait_stream(transport_stream) # current_stream.wait_stream(transport_stream)
see_memory_usage("In backward checkpointing code before forward", force=False)
with torch.enable_grad(): with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs) outputs = ctx.run_function(*detached_inputs)
see_memory_usage("In backward checkpointing code after forward", force=False)
# Set the states back to what it was at the start of this function. # Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state) torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state) _set_cuda_rng_state(bwd_cuda_rng_state)
...@@ -646,8 +681,13 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -646,8 +681,13 @@ class CheckpointFunction(torch.autograd.Function):
output_tensors.append(out) output_tensors.append(out)
grad_tensors.append(grad) grad_tensors.append(grad)
see_memory_usage("In backward checkpointing code before backward", force=False)
torch.autograd.backward(output_tensors, grad_tensors) torch.autograd.backward(output_tensors, grad_tensors)
see_memory_usage("After backward checkpointing code before backward",
force=False)
if PROFILE_TIME: if PROFILE_TIME:
timers('backward').stop() timers('backward').stop()
timers.log(['backward']) timers.log(['backward'])
...@@ -706,8 +746,8 @@ def reset(): ...@@ -706,8 +746,8 @@ def reset():
for buffers in contiguous_data_buffers: for buffers in contiguous_data_buffers:
buffers = [] buffers = []
#frees up all the pointers to the checkpoints except for the ones # frees up all the pointers to the checkpoints except for the ones
#stored by save for backward # stored by save for backward
contiguous_data_buffers = [] contiguous_data_buffers = []
contiguous_size_buffers = [] contiguous_size_buffers = []
data_offsets = [] data_offsets = []
...@@ -719,6 +759,7 @@ def _configure_using_config_file(deepspeed_config, mpu=None): ...@@ -719,6 +759,7 @@ def _configure_using_config_file(deepspeed_config, mpu=None):
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config
if dist.get_rank() == 0:
logger.info(config.repr()) logger.info(config.repr())
PARTITION_ACTIVATIONS = config.partition_activations PARTITION_ACTIVATIONS = config.partition_activations
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
......
...@@ -752,8 +752,8 @@ class DeepSpeedConfig(object): ...@@ -752,8 +752,8 @@ class DeepSpeedConfig(object):
if self.zero_enabled: if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled" assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION) assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
if self.zero_config.cpu_offload is True: #if self.zero_config.cpu_offload is True:
assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS) # assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
#assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD) #assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
def _do_warning_check(self): def _do_warning_check(self):
......
...@@ -12,8 +12,10 @@ from torch.nn.modules import Module ...@@ -12,8 +12,10 @@ from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1 from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
...@@ -27,7 +29,7 @@ from deepspeed.runtime.constants import \ ...@@ -27,7 +29,7 @@ from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA PLD_THETA, PLD_GAMMA
from deepspeed.runtime.zero.constants import \ from deepspeed.runtime.zero.constants import \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.runtime.csr_tensor import CSRTensor from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger, log_dist, init_distributed from deepspeed.utils import logger, log_dist, init_distributed
...@@ -105,8 +107,10 @@ class DeepSpeedEngine(Module): ...@@ -105,8 +107,10 @@ class DeepSpeedEngine(Module):
mpu=None, mpu=None,
dist_init_required=None, dist_init_required=None,
collate_fn=None, collate_fn=None,
config_params=None): config_params=None,
dont_change_device=False):
super(DeepSpeedEngine, self).__init__() super(DeepSpeedEngine, self).__init__()
self.dont_change_device = dont_change_device
self.client_optimizer = optimizer self.client_optimizer = optimizer
self.client_model_parameters = model_parameters self.client_model_parameters = model_parameters
self.client_lr_scheduler = lr_scheduler self.client_lr_scheduler = lr_scheduler
...@@ -136,6 +140,7 @@ class DeepSpeedEngine(Module): ...@@ -136,6 +140,7 @@ class DeepSpeedEngine(Module):
# Initialize torch distributed if needed # Initialize torch distributed if needed
init_distributed(dist_backend=self.dist_backend) init_distributed(dist_backend=self.dist_backend)
see_memory_usage(f"DeepSpeed Engine: Before args sanity test")
self._do_args_sanity_check(args) self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu) self._configure_with_arguments(args, mpu)
self._do_sanity_check() self._do_sanity_check()
...@@ -149,9 +154,13 @@ class DeepSpeedEngine(Module): ...@@ -149,9 +154,13 @@ class DeepSpeedEngine(Module):
if self.tensorboard_enabled() and self.global_rank == 0: if self.tensorboard_enabled() and self.global_rank == 0:
self.summary_writer = self.get_summary_writer() self.summary_writer = self.get_summary_writer()
see_memory_usage(f"DeepSpeed Engine: Before configure distributed model")
# Configure distributed model # Configure distributed model
self._configure_distributed_model(model) self._configure_distributed_model(model)
see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
# Configure wall clock timer # Configure wall clock timer
self.timers = SynchronizedWallClockTimer() self.timers = SynchronizedWallClockTimer()
...@@ -331,6 +340,15 @@ class DeepSpeedEngine(Module): ...@@ -331,6 +340,15 @@ class DeepSpeedEngine(Module):
def zero_cpu_offload(self): def zero_cpu_offload(self):
return self._config.zero_config.cpu_offload return self._config.zero_config.cpu_offload
def zero_cpu_offload_params(self):
return self._config.zero_config.cpu_offload_params
def zero_cpu_offload_use_pin_memory(self):
return self._config.zero_config.cpu_offload_use_pin_memory
def zero_sub_group_size(self):
return self._config.zero_config.sub_group_size
def zero_optimization_stage(self): def zero_optimization_stage(self):
return self._config.zero_optimization_stage return self._config.zero_optimization_stage
...@@ -343,6 +361,9 @@ class DeepSpeedEngine(Module): ...@@ -343,6 +361,9 @@ class DeepSpeedEngine(Module):
def zero_optimization_partition_gradients(self): def zero_optimization_partition_gradients(self):
return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS
def zero_optimization_partition_weights(self):
return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_WEIGHTS
def zero_contiguous_gradients(self): def zero_contiguous_gradients(self):
return self._config.zero_config.contiguous_gradients return self._config.zero_config.contiguous_gradients
...@@ -352,6 +373,18 @@ class DeepSpeedEngine(Module): ...@@ -352,6 +373,18 @@ class DeepSpeedEngine(Module):
def zero_elastic_checkpoint(self): def zero_elastic_checkpoint(self):
return self._config.zero_config.elastic_checkpoint return self._config.zero_config.elastic_checkpoint
def zero_max_live_parameters(self):
return self._config.zero_config.max_live_parameters
def zero_max_reuse_distance(self):
return self._config.zero_config.max_reuse_distance
def zero_prefetch_bucket_size(self):
return self._config.zero_config.prefetch_bucket_size
def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold
def fp16_enabled(self): def fp16_enabled(self):
return self._config.fp16_enabled return self._config.fp16_enabled
...@@ -418,7 +451,8 @@ class DeepSpeedEngine(Module): ...@@ -418,7 +451,8 @@ class DeepSpeedEngine(Module):
dp_rank = self.mpu.get_data_parallel_rank() dp_rank = self.mpu.get_data_parallel_rank()
# only the first data parallel process needs to store the model checkpoint # only the first data parallel process needs to store the model checkpoint
self.save_non_zero_checkpoint = (dp_rank == 0) self.save_non_zero_checkpoint = (
dp_rank == 0) or self.zero_optimization_partition_weights()
if self.zero_optimization(): if self.zero_optimization():
param_rank = torch.distributed.get_rank( param_rank = torch.distributed.get_rank(
...@@ -512,8 +546,13 @@ class DeepSpeedEngine(Module): ...@@ -512,8 +546,13 @@ class DeepSpeedEngine(Module):
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name()) 'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())
def _broadcast_model(self): def _broadcast_model(self):
def is_replicated(p):
if hasattr(p, 'ds_status') and p.ds_status is not ZeroParamStatus.AVAILABLE:
return False
return True
for p in self.module.parameters(): for p in self.module.parameters():
if torch.is_tensor(p): if torch.is_tensor(p) and is_replicated(p):
dist.broadcast(p, dist.broadcast(p,
self.broadcast_src_rank, self.broadcast_src_rank,
group=self.data_parallel_group) group=self.data_parallel_group)
...@@ -522,6 +561,8 @@ class DeepSpeedEngine(Module): ...@@ -522,6 +561,8 @@ class DeepSpeedEngine(Module):
self.module = model self.module = model
if self.fp16_enabled(): if self.fp16_enabled():
self.module.half() self.module.half()
if not self.dont_change_device:
self.module.to(self.device) self.module.to(self.device)
if self.mpu is None: if self.mpu is None:
...@@ -555,7 +596,8 @@ class DeepSpeedEngine(Module): ...@@ -555,7 +596,8 @@ class DeepSpeedEngine(Module):
self.optimizer_name())) self.optimizer_name()))
if self.global_rank == 0: if self.global_rank == 0:
logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer)) logger.info('DeepSpeed Basic Optimizer = {}'.format(
basic_optimizer.__class__.__name__))
if self.zero_optimization(): if self.zero_optimization():
assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2" assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
...@@ -585,7 +627,8 @@ class DeepSpeedEngine(Module): ...@@ -585,7 +627,8 @@ class DeepSpeedEngine(Module):
self.optimizer = self._configure_fp16_optimizer(basic_optimizer) self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
else: else:
self.optimizer = basic_optimizer self.optimizer = basic_optimizer
logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer)) log_dist('DeepSpeed Final Optimizer = {}'.format(self.optimizer_name()),
ranks=[0])
def _configure_basic_optimizer(self, model_parameters): def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params() optimizer_parameters = self.optimizer_params()
...@@ -636,7 +679,7 @@ class DeepSpeedEngine(Module): ...@@ -636,7 +679,7 @@ class DeepSpeedEngine(Module):
if isinstance(optimizer, if isinstance(optimizer,
FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER: FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
if self.dynamic_loss_scale(): if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale') log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
timers = self.timers if self.wall_clock_breakdown() else None timers = self.timers if self.wall_clock_breakdown() else None
optimizer = FP16_Optimizer( optimizer = FP16_Optimizer(
optimizer, optimizer,
...@@ -648,8 +691,9 @@ class DeepSpeedEngine(Module): ...@@ -648,8 +691,9 @@ class DeepSpeedEngine(Module):
fused_adam_legacy=self.optimizer_legacy_fusion(), fused_adam_legacy=self.optimizer_legacy_fusion(),
timers=timers) timers=timers)
else: else:
logger.info('Creating fp16 optimizer with static loss scale: {}'.format( log_dist('Creating fp16 optimizer with static loss scale: {}'.format(
self.loss_scale())) self.loss_scale()),
ranks=[0])
optimizer = FP16_Optimizer( optimizer = FP16_Optimizer(
optimizer, optimizer,
static_loss_scale=self.loss_scale(), static_loss_scale=self.loss_scale(),
...@@ -657,7 +701,8 @@ class DeepSpeedEngine(Module): ...@@ -657,7 +701,8 @@ class DeepSpeedEngine(Module):
clip_grad=clip_grad, clip_grad=clip_grad,
fused_adam_legacy=self.optimizer_legacy_fusion()) fused_adam_legacy=self.optimizer_legacy_fusion())
else: else:
logger.info('Creating fp16 unfused optimizer with dynamic loss scale') log_dist('Creating fp16 unfused optimizer with dynamic loss scale',
ranks=[0])
optimizer = FP16_UnfusedOptimizer( optimizer = FP16_UnfusedOptimizer(
optimizer, optimizer,
static_loss_scale=self.loss_scale(), static_loss_scale=self.loss_scale(),
...@@ -671,8 +716,9 @@ class DeepSpeedEngine(Module): ...@@ -671,8 +716,9 @@ class DeepSpeedEngine(Module):
def _configure_zero_optimizer(self, optimizer): def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage() zero_stage = self.zero_optimization_stage()
logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage)) log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true" assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES: if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode' assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
optimizer = FP16_DeepSpeedZeroOptimizer_Stage1( optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
...@@ -706,6 +752,35 @@ class DeepSpeedEngine(Module): ...@@ -706,6 +752,35 @@ class DeepSpeedEngine(Module):
postscale_gradients=self.postscale_gradients(), postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(), gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps()) gradient_accumulation_steps=self.gradient_accumulation_steps())
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=self.timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
cpu_offload_optimizer_state=self.zero_cpu_offload(),
cpu_offload_params=self.zero_cpu_offload_params(),
cpu_offload_use_pin_memory=self.zero_cpu_offload_use_pin_memory(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps())
else: else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
...@@ -817,6 +892,11 @@ class DeepSpeedEngine(Module): ...@@ -817,6 +892,11 @@ class DeepSpeedEngine(Module):
self.tput_timer.start() self.tput_timer.start()
loss = self.module(*inputs, **kwargs) loss = self.module(*inputs, **kwargs)
# Reset the ZeRO-3 state if we are only doing forward passes (i.e., evaluation). # Reset the ZeRO-3 state if we are only doing forward passes (i.e., evaluation).
if self.zero_optimization_partition_weights():
if not torch._C.is_grad_enabled():
self.optimizer.param_coordinator.reset_step()
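For context, a hedged sketch of the evaluation-only case this hook covers; engine and eval_loader are assumed to come from a standard deepspeed.initialize setup.

engine.eval()
with torch.no_grad():               # forward passes only, no backward to trigger cleanup
    for batch in eval_loader:
        loss = engine(batch)        # forward() above resets the ZeRO-3 param coordinator per call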
if self.wall_clock_breakdown(): if self.wall_clock_breakdown():
self.timers('forward').stop() self.timers('forward').stop()
self.timers('forward_microstep').stop() self.timers('forward_microstep').stop()
...@@ -1267,7 +1347,16 @@ class DeepSpeedEngine(Module): ...@@ -1267,7 +1347,16 @@ class DeepSpeedEngine(Module):
def _get_ckpt_name(self, checkpoints_path, tag): def _get_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
ckpt_name = os.path.join(checkpoints_path, if self.zero_optimization_partition_weights():
filename = 'zero_pp_rank_{}'.format(
torch.distributed.get_rank(group=self.optimizer.dp_process_group))
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
filename + '_mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
else:
ckpt_name = os.path.join(
checkpoints_path,
str(tag), str(tag),
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt') 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
return ckpt_name return ckpt_name
...@@ -1478,6 +1567,10 @@ class DeepSpeedEngine(Module): ...@@ -1478,6 +1567,10 @@ class DeepSpeedEngine(Module):
process with rank 0. process with rank 0.
""" """
if self.zero_optimization_partition_weights():
# Prepare for state_dict() by ensuring all parameters are partitioned
self.optimizer.save_checkpoint_prologue()
# This is to make sure the checkpoint names are created without collision # This is to make sure the checkpoint names are created without collision
# There seems to be issue creating them in parallel # There seems to be issue creating them in parallel
...@@ -1506,6 +1599,9 @@ class DeepSpeedEngine(Module): ...@@ -1506,6 +1599,9 @@ class DeepSpeedEngine(Module):
with open(os.path.join(save_dir, 'latest'), 'w') as fd: with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag) fd.write(tag)
if self.zero_optimization_partition_weights():
self.optimizer.save_checkpoint_epilogue()
return True return True
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
......
...@@ -7,6 +7,7 @@ Helper functions and classes from multiple sources. ...@@ -7,6 +7,7 @@ Helper functions and classes from multiple sources.
''' '''
import os import os
import psutil
from math import ceil from math import ceil
from math import floor from math import floor
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
...@@ -72,7 +73,7 @@ class CheckOverflow(object): ...@@ -72,7 +73,7 @@ class CheckOverflow(object):
self.params.append(param) self.params.append(param)
def check_using_norm(self, norm_group, reduce_overflow=True): def check_using_norm(self, norm_group, reduce_overflow=True):
#TODO: I don't think reduce_overflow is needed if mpu is None # TODO: I don't think reduce_overflow is needed if mpu is None
overflow = -1 in norm_group overflow = -1 in norm_group
if self.mpu is not None: if self.mpu is not None:
...@@ -115,7 +116,7 @@ class CheckOverflow(object): ...@@ -115,7 +116,7 @@ class CheckOverflow(object):
# Since each model parallel GPU carries only part of the model, # Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs # make sure overflow flag is synced across all the model parallel GPUs
overflow_gpu = torch.cuda.ByteTensor([overflow]) overflow_gpu = torch.cuda.ByteTensor([overflow])
#torch.distributed.all_reduce(overflow_gpu, # torch.distributed.all_reduce(overflow_gpu,
# op=torch.distributed.ReduceOp.MAX, # op=torch.distributed.ReduceOp.MAX,
# group=mpu.get_model_parallel_group()) # group=mpu.get_model_parallel_group())
if self.zero_reduce_scatter: if self.zero_reduce_scatter:
...@@ -544,7 +545,8 @@ def memory_status(msg, print_rank=-1, reset_max=False): ...@@ -544,7 +545,8 @@ def memory_status(msg, print_rank=-1, reset_max=False):
) )
def see_memory_usage(message): def see_memory_usage(message, force=False):
if not force:
return return
if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0: if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
return return
...@@ -557,6 +559,11 @@ def see_memory_usage(message): ...@@ -557,6 +559,11 @@ def see_memory_usage(message):
CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024),2)} GB \ CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024),2)} GB \
Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ") Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ")
vm_stats = psutil.virtual_memory()
used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
logger.info(
f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')
def call_to_str(base, *args, **kwargs): def call_to_str(base, *args, **kwargs):
"""Construct a string representation of a call. """Construct a string representation of a call.
......
from .partition_parameters import ZeroParamType
from .partition_parameters import ZeroParamStatus
from .partition_parameters import Init
from .partition_parameters import GatheredParameters
from .partition_parameters import register_external_parameter
...@@ -21,9 +21,27 @@ class DeepSpeedZeroConfig(object): ...@@ -21,9 +21,27 @@ class DeepSpeedZeroConfig(object):
self.allgather_bucket_size = None self.allgather_bucket_size = None
self.overlap_comm = None self.overlap_comm = None
self.load_from_fp32_weights = None self.load_from_fp32_weights = None
self.cpu_offload = None
self.elastic_checkpoint = None self.elastic_checkpoint = None
#Offload Specific Parameters
self.cpu_offload = None
self.cpu_offload_params = None
self.cpu_offload_use_pin_memory = None
self.sub_group_size = None
#Stage3 Specific Parameters
self.prefetch_bucket_size = None
self.param_persistence_threshold = None
self.max_live_parameters = None
self.max_reuse_distance = None
if ZERO_OPTIMIZATION in param_dict.keys(): if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION] zero_config_dict = param_dict[ZERO_OPTIMIZATION]
if type(zero_config_dict) is bool: if type(zero_config_dict) is bool:
...@@ -66,6 +84,8 @@ class DeepSpeedZeroConfig(object): ...@@ -66,6 +84,8 @@ class DeepSpeedZeroConfig(object):
self.contiguous_gradients = get_scalar_param( self.contiguous_gradients = get_scalar_param(
zero_config_dict, zero_config_dict,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS, ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT
if self.stage == ZERO_OPTIMIZATION_WEIGHTS else
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT) ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
self.reduce_bucket_size = get_scalar_param( self.reduce_bucket_size = get_scalar_param(
...@@ -77,8 +97,11 @@ class DeepSpeedZeroConfig(object): ...@@ -77,8 +97,11 @@ class DeepSpeedZeroConfig(object):
ZERO_OPTIMIZATION_REDUCE_SCATTER, ZERO_OPTIMIZATION_REDUCE_SCATTER,
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT) ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)
self.overlap_comm = get_scalar_param(zero_config_dict, self.overlap_comm = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_OVERLAP_COMM, ZERO_OPTIMIZATION_OVERLAP_COMM,
ZERO3_OPTIMIZATION_OVERLAP_COMM_DEFAULT
if self.stage == ZERO_OPTIMIZATION_WEIGHTS else
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT) ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)
self.allgather_partitions = get_scalar_param( self.allgather_partitions = get_scalar_param(
...@@ -104,3 +127,37 @@ class DeepSpeedZeroConfig(object): ...@@ -104,3 +127,37 @@ class DeepSpeedZeroConfig(object):
zero_config_dict, zero_config_dict,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT, ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT) ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)
self.cpu_offload_params = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS,
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT)
self.cpu_offload_use_pin_memory = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY,
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY_DEFAULT)
self.sub_group_size = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_SUB_GROUP_SIZE,
ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT)
self.max_live_parameters = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS,
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT)
self.max_reuse_distance = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE,
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT)
self.prefetch_bucket_size = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE,
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT)
self.param_persistence_threshold = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)
...@@ -13,14 +13,19 @@ ZeRO optimization should be enabled as: ...@@ -13,14 +13,19 @@ ZeRO optimization should be enabled as:
"session_params": { "session_params": {
"zero_optimization": { "zero_optimization": {
"stage": [0|1|2], "stage": [0|1|2],
"stage3_max_live_parameters" : 1000000000,
"stage3_max_reuse_distance" : 1000000000,
"allgather_partitions": [true|false], "allgather_partitions": [true|false],
"allgather_bucket_size": 500000000, "allgather_bucket_size": 500000000,
"reduce_scatter": [true|false], "reduce_scatter": [true|false],
"contiguous_gradients" : [true|false] "contiguous_gradients" : [true|false]
"overlap_comm": [true|false], "overlap_comm": [true|false],
"reduce_bucket_size": 500000000 "reduce_bucket_size": 500000000,
"load_from_fp32_weights": [true|false] "load_from_fp32_weights": [true|false],
"cpu_offload": [true|false] "cpu_offload": [true|false],
"cpu_offload_params" : [true|false],
"cpu_offload_use_pin_memory": [true|false],
"sub_group_size" : 1000000000000
} }
} }
''' '''
...@@ -30,7 +35,7 @@ ZERO_OPTIMIZATION_DISABLED = 0 ...@@ -30,7 +35,7 @@ ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1 ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2 ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3 ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_WEIGHTS
ZERO_OPTIMIZATION_STAGE = 'stage' ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1' ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
...@@ -47,9 +52,11 @@ ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True ...@@ -47,9 +52,11 @@ ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm' ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO3_OPTIMIZATION_OVERLAP_COMM_DEFAULT = True
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients' ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size' ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000 ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
...@@ -66,18 +73,65 @@ ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False ...@@ -66,18 +73,65 @@ ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT = 'elastic_checkpoint' ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT = 'elastic_checkpoint'
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT = True ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT = True
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS = 'cpu_offload_params'
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT = False
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY = 'cpu_offload_use_pin_memory'
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY_DEFAULT = False
ZERO_OPTIMIZATION_SUB_GROUP_SIZE = 'sub_group_size'
ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT = 1000000000000
#maximum number of parameters per GPU before releasing them
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS = 'stage3_max_live_parameters'
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT = 1000000000
#release a parameter only if the reuse distance is larger than specified
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE = 'stage3_max_reuse_distance'
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT = 1000000000
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE = 'stage3_prefetch_bucket_size'
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT = 50000000
#parameters smaller than the threshold are only communicated once after they
#are updated and are persisted throughout training,
#avoiding many latency-bound communication calls
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD = 'stage3_param_persistence_threshold'
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS: ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS: ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE: ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS: ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT: ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS: ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY: ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY_DEFAULT,
ZERO_OPTIMIZATION_SUB_GROUP_SIZE: ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT,
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS: ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT,
ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE: ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT,
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE: ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD: ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT
}
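Putting the new knobs together, a minimal ZeRO-3 Offload configuration sketch built from the keys defined above; the values are illustrative rather than tuned recommendations, and the tiny stand-in model assumes a GPU plus the usual deepspeed launcher.

import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)          # stand-in model

ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},               # ZeRO requires fp16 per the config check above
    "zero_optimization": {
        "stage": 3,
        "cpu_offload": True,                 # offload optimizer state to CPU
        "cpu_offload_params": True,          # offload fp16 parameters to CPU
        "cpu_offload_use_pin_memory": True,
        "sub_group_size": 1000000000000,
        "stage3_max_live_parameters": 1000000000,
        "stage3_max_reuse_distance": 1000000000,
        "stage3_prefetch_bucket_size": 50000000,
        "stage3_param_persistence_threshold": 100000,
    },
}

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config_params=ds_config)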
import torch
def print_rank_0(message):
if torch.distributed.get_rank() == 0:
print(message)
class ContiguousMemoryAllocator(object):
def __init__(self, size, dtype, device):
self.buffer = torch.zeros(size, dtype=dtype, device=device)
#address to contiguous size available
self.contiguous_sizes = {}
self.contiguous_sizes[0] = size
#tensor id to its address
self.tensor_addresses = {}
#tensor address to its size
self.tensor_sizes = {}
#tensor address to ids
self.tensor_ids = {}
#id to tensors
self.tensor_map = {}
#id to params. Maps each tensor buffer to list of parameters that uses it
self.id_to_params = {}
self.total_size = size
self.total_free = size
self.largest_contiguous = size
self.max_allocated = 0
self.count = 0
#create a tensor of the given size from the pre-allocated buffer
#fails if there is not enough free space
#if there is not enough contiguous space, defragments and then allocates
def allocate_tensor(self, size):
free_before = self.total_free
assert size <= self.total_free, "Not enough memory in buffer. Allocation failed"
if self.largest_contiguous < size:
print_rank_0("Needs defragmentation to allocate. Before Defragmentation:")
self.print_allocation(resolution=100)
self._defragment_memory()
#set the param data to the new tensor buffer locations
self._reset_param_data()
print_rank_0("After defragmentation:")
self.print_allocation(resolution=100)
self.total_free = self.total_free - size
allocated = self.total_size - self.total_free
if allocated > self.max_allocated:
self.max_allocated = allocated
tensor_address = self._get_new_tensor_address(size)
ret_tensor = self._get_new_tensor(tensor_address, size)
print_rank_0(
f"Free before allocation {free_before}. Allocating {size}. Free after allocation {self.total_free}. Max allocated {self.max_allocated}"
)
assert self.total_free + size == free_before, "Allocation bookkeeping error"
return ret_tensor
#assigns the tensor data to the param data and keeps track of the assignment
#any change to the underlying buffer from defragmentation will cause a
#reassignment of the param data
def assign_to_param(self, tensor, param, numel, shape):
tensor_id = id(tensor)
assert tensor_id in self.tensor_map.keys(), "No such tensor allocated by the allocator."
assert tensor.numel() >= numel, "Tensor buffer is not large enough"
assert not tensor_id in self.id_to_params.keys(), "This tensor has already been assigned to a param"
self.id_to_params[tensor_id] = [param]
replicated_tensor = tensor.narrow(0, 0, numel).view(shape)
param.data = replicated_tensor.data
param.contiguous_tensor_id = tensor_id
#deletes the tensor and frees up the underlying buffer
def release_tensor(self, tensor):
free_before = self.total_free
tensor_id = id(tensor)
tensor_size = tensor.numel()
self._release_tensor(tensor_id)
self._unassign_params(tensor_id)
self.total_free += tensor_size
print_rank_0(
f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
)
assert self.total_free - tensor_size == free_before, "Release bookkeeping error"
def release_tensor_with_id(self, tensor_id):
free_before = self.total_free
assert tensor_id in self.tensor_map.keys(), "Invalid tensor id"
tensor = self.tensor_map[tensor_id]
tensor_size = tensor.numel()
self._release_tensor(tensor_id)
self._unassign_params(tensor_id)
self.total_free += tensor_size
print_rank_0(
f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
)
assert self.total_free - tensor_size == free_before, "Release bookkeeping error"
#shows the current memory allocation at specified resolution
def print_allocation(self, resolution=200):
total_size = self.buffer.numel() * 1.0
empty = []
for addr, size in self.contiguous_sizes.items():
start = int(addr * resolution / total_size)
end = int((addr + size) * resolution / total_size)
empty.extend(range(start, end))
s = ''
for i in range(resolution):
s += '.' if i in empty else '|'
print_rank_0(s)
def max_allocated(self):
return self.max_allocated
#to be called after defragmentation that moves the tensor buffers
#this call reassigns the data of all the parameters using the tensor buffers
def _reset_param_data(self):
for id, tensor in self.tensor_map.items():
for param in self.id_to_params[id]:
param.data = tensor.narrow(0,
0,
param.numel()).view(param.data.shape).data
def _unassign_params(self, tensor_id):
if tensor_id in self.id_to_params.keys():
del self.id_to_params[tensor_id]
def _release_tensor(self, tensor_id):
assert tensor_id in self.tensor_addresses, f"Tensor id {tensor_id} not found"
address = self.tensor_addresses[tensor_id]
contiguous_size = self.tensor_map[tensor_id].numel()
del self.tensor_addresses[tensor_id]
del self.tensor_ids[address]
del self.tensor_map[tensor_id]
del self.tensor_sizes[address]
self._consolidate_address(address, contiguous_size)
self.largest_contiguous = self._largest_contiguous()
def _consolidate_address(self, address, contiguous_size):
#consolidate next buffer
end_address = address + contiguous_size
if end_address in self.contiguous_sizes:
contiguous_size += self.contiguous_sizes[end_address]
del self.contiguous_sizes[end_address]
#consolidate previous buffer
for addr, size in self.contiguous_sizes.items():
if addr + size == address:
del self.contiguous_sizes[addr]
contiguous_size += size
address = addr
break
self.contiguous_sizes[address] = contiguous_size
def _defragment_memory(self):
empty_addresses = sorted(self.contiguous_sizes.keys())
tensor_addresses = sorted(self.tensor_addresses.values())
tensor_index = 0
while tensor_index < len(tensor_addresses):
empty_addr = empty_addresses[0]
empty_size = self.contiguous_sizes[empty_addr]
tensor_addr = tensor_addresses[tensor_index]
tensor_size = self.tensor_sizes[tensor_addr]
tensor_id = self.tensor_ids[tensor_addr]
tensor = self.tensor_map[self.tensor_ids[tensor_addr]]
            assert tensor_size == tensor.numel(), \
                f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} "
assert empty_addr != tensor_addr, \
f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}"
if empty_addr < tensor_addr:
if empty_size >= tensor_size:
dest_buffer = self.buffer.narrow(0, empty_addr, tensor_size)
src_buffer = self.buffer.narrow(0, tensor_addr, tensor_size)
dest_buffer.data.copy_(src_buffer.data)
else:
#print_rank_0(f'empty addr : {empty_addr}, empty size {empty_size} tensor addr {tensor_addr} tensor size {tensor_size}')
src_addr = tensor_addr
dest_addr = empty_addr
while src_addr < (tensor_addr + tensor_size):
copy_size = min(empty_size, tensor_addr + tensor_size - src_addr)
dest_buffer = self.buffer.narrow(0, dest_addr, copy_size)
src_buffer = self.buffer.narrow(0, src_addr, copy_size)
dest_buffer.data.copy_(src_buffer.data)
src_addr += copy_size
dest_addr += copy_size
self._replace_old_address_with_new(tensor_id, empty_addr)
tensor_index += 1
else:
tensor_index += 1
empty_addresses = sorted(self.contiguous_sizes.keys())
def _replace_old_address_with_new(self, tensor_id, new_address):
tensor = self.tensor_map[tensor_id]
tensor_size = tensor.numel()
tensor.data = self.buffer.narrow(0, new_address, tensor_size).data
self._release_tensor(tensor_id)
self._mark_as_occupied(new_address, tensor_size)
self.tensor_ids[new_address] = tensor_id
self.tensor_map[tensor_id] = tensor
self.tensor_addresses[tensor_id] = new_address
self.tensor_sizes[new_address] = tensor_size
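    #best-fit search: returns the address of the smallest free region that can hold the requested size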
def _get_new_tensor_address(self, size):
tensor_address = None
for address, contiguous_size in self.contiguous_sizes.items():
if contiguous_size >= size and \
(tensor_address is None or \
contiguous_size < self.contiguous_sizes[tensor_address]):
tensor_address = address
assert tensor_address is not None, "address cannot be None"
return tensor_address
def _get_new_tensor(self, address, size):
available_contiguous_size = self.contiguous_sizes[address]
assert size <= available_contiguous_size, \
f"Tensor numel {size} is large than available contiguous size {available_contiguous_size}"
self.count += 1
new_tensor = self.buffer.narrow(0, address, size)
tensor_id = id(new_tensor)
self.tensor_addresses[tensor_id] = address
self.tensor_sizes[address] = size
self.tensor_ids[address] = tensor_id
self.tensor_map[tensor_id] = new_tensor
self._mark_as_occupied(address, size)
return new_tensor
def _largest_contiguous(self):
if len(self.contiguous_sizes) > 0:
return max([size for _, size in self.contiguous_sizes.items()])
else:
return 0
def _mark_as_occupied(self, address, size):
available_contiguous_size = self.contiguous_sizes[address]
del self.contiguous_sizes[address]
if available_contiguous_size != size:
self.contiguous_sizes[address + size] = available_contiguous_size - size
self.largest_contiguous = self._largest_contiguous()
#Linear Module to use with ZeRO Stage 3 to allow for parameter memory release
#after the module execution during forward
#Instead of saving variables using save_for_backward, we save variable ids
#Allowing us to retrieve the variable without creating a pointer to it
#Which allows the underlying tensor to be garbage collected
#When partitioned as needed by the ZeRO Stage 3 optimizer
#TODO instead of patching the Linear module, we could patch ctx.save_for_backward
#and ctx.saved_tensors so that this approach works for all nn modules that are built upon
#torch.nn.functional. However the issue is that many modules use C++ implementations
#which do not have a PyTorch implementation. Eg torch.addmm, which acts as a functional
#when implemented outside of torch.autograd.Function
import math
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn.modules.module import Module
tensor_map = {}
class LinearFunctionForZeroStage3(torch.autograd.Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
#print("In ZeRO Linear Function")
weight_id = id(weight)
bias_id = id(bias)
#ctx.save_for_backward(input, weight, bias)
ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id))
tensor_map[weight_id] = weight
tensor_map[bias_id] = bias
if input.dim() == 2 and bias is not None:
# fused op is marginally faster
ret = torch.addmm(bias, input, weight.t())
else:
output = input.matmul(weight.t())
if bias is not None:
output += bias
ret = output
return ret
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
#input, weight, bias = ctx.saved_tensors
input, weight_id, bias_id = ctx.saved_tensors
weight = tensor_map[weight_id.item()]
bias = tensor_map[bias_id.item()]
grad_input = grad_weight = grad_bias = None
#print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
#print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
grad_input = grad_output.matmul(weight)
#print(f"Computed grad input {grad_input.shape}")
if ctx.needs_input_grad[1]:
#print("Computing grad weight")
dim = grad_output.dim()
if dim > 2:
grad_weight = grad_output.view(-1,
grad_output.shape[-1]).t().matmul(
input.view(-1,
input.shape[-1]))
else:
grad_weight = grad_output.t().matmul(input)
#print(f"Computed grad weight grad_weight {grad_weight.shape}")
if bias is not None and ctx.needs_input_grad[2]:
#print("Computing grad bias")
grad_bias = grad_output.sum(0)
#print("Done computing grad bias")
#print("needs bias")
#print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
return grad_input, grad_weight, grad_bias
class LinearModuleForZeroStage3(Module):
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
The weights are pre-transposed and stored as A^T instead of transposing during each
forward. Memory savings proportional to the parameter size.
Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
Shape:
- Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
additional dimensions and :math:`H_{in} = \text{in\_features}`
- Output: :math:`(N, *, H_{out})` where all but the last dimension
are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
Attributes:
weight: the learnable weights of the module of shape
:math:`(\text{out\_features}, \text{in\_features})`. The values are
initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = \frac{1}{\text{in\_features}}`
bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{1}{\text{in\_features}}`
Examples::
        >>> m = LinearModuleForZeroStage3(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
super(LinearModuleForZeroStage3, self).__init__()
print("Building ZeRO module")
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor) -> Tensor:
return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features,
self.out_features,
self.bias is not None)
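For reference, here is a minimal usage sketch of the module above, assuming this file is importable as `deepspeed.runtime.zero.linear`; it is intended to swap in for `torch.nn.Linear`:

```python
# Minimal sketch; the import path is assumed from this file's location.
import torch
from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3

# Drop-in replacement for torch.nn.Linear; backward looks up weight/bias by id
# instead of holding references, so ZeRO stage 3 can release the partitions.
layer = LinearModuleForZeroStage3(in_features=20, out_features=30, bias=True)
x = torch.randn(128, 20)
y = layer(x)            # routed through LinearFunctionForZeroStage3.apply
y.sum().backward()      # gradients flow to layer.weight and layer.bias

print(y.size())                   # torch.Size([128, 30])
print(layer.weight.grad.size())   # torch.Size([30, 20])
```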
import torch
from deepspeed.runtime.zero.contiguous_memory_allocator import ContiguousMemoryAllocator
def test1():
mem = ContiguousMemoryAllocator(1024, torch.half, 'cpu')
mem.print_allocation(resolution=100)
a1 = mem.allocate_tensor(64).mul_(0.0).add_(1.0)
mem.print_allocation(resolution=100)
mem.release_tensor(a1)
mem.print_allocation(resolution=100)
a2 = mem.allocate_tensor(64).mul_(0.0).add_(2.0)
a3 = mem.allocate_tensor(256).mul_(0.0).add_(3.0)
a4 = mem.allocate_tensor(128).mul_(0.0).add_(4.0)
mem.print_allocation(resolution=100)
mem.release_tensor(a3)
mem.print_allocation(resolution=100)
a5 = mem.allocate_tensor(64).mul_(0.0).add_(5.0)
a6 = mem.allocate_tensor(256).mul_(0.0).add_(6.0)
a7 = mem.allocate_tensor(128).mul_(0.0).add_(7.0)
mem.print_allocation(resolution=100)
a8 = mem.allocate_tensor(256).mul_(0.0).add_(8.0)
a9 = mem.allocate_tensor(128).mul_(0.0).add_(9.0)
mem.print_allocation(resolution=100)
mem.release_tensor(a9)
mem.release_tensor(a6)
mem.release_tensor(a2)
mem.release_tensor(a5)
a10 = mem.allocate_tensor(512).mul_(0.0).add_(10.0)
mem.print_allocation(resolution=100)
#print(f"a4:{a4}")
#print(f"a7:{a7}")
#print(f"a8:{a8}")
#print(f"a10:{a10}")
assert (a4.norm() + a7.norm() + a8.norm() + a10.norm()).item() == 474.50, "Test failed"
def test2():
mem = ContiguousMemoryAllocator(512, torch.half, 'cpu')
a1 = mem.allocate_tensor(64).mul_(0.0).add_(1.0)
a2 = mem.allocate_tensor(64).mul_(0.0).add_(2.0)
a3 = mem.allocate_tensor(64).mul_(0.0).add_(3.0)
a4 = mem.allocate_tensor(64).mul_(0.0).add_(4.0)
a5 = mem.allocate_tensor(64).mul_(0.0).add_(5.0)
a6 = mem.allocate_tensor(64).mul_(0.0).add_(6.0)
a7 = mem.allocate_tensor(64).mul_(0.0).add_(7.0)
a8 = mem.allocate_tensor(64).mul_(0.0).add_(8.0)
mem.release_tensor(a2)
mem.release_tensor(a4)
mem.release_tensor(a6)
mem.release_tensor(a8)
mem.print_allocation(resolution=100)
a9 = mem.allocate_tensor(128).mul_(0.0).add_(9.0)
a10 = mem.allocate_tensor(64).mul_(0.0).add_(10.0)
a11 = mem.allocate_tensor(64).mul_(0.0).add_(11.0)
mem.release_tensor(a1)
mem.release_tensor(a5)
mem.print_allocation(resolution=100)
a12 = mem.allocate_tensor(128).mul_(0.0).add_(12.0)
mem.print_allocation(resolution=100)
print(f"a7:{a7}")
print(f"a9:{a9}")
print(f"a10:{a10}")
print(f"a11:{a11}")
print(f"a12:{a12}")
assert (a7.norm() + a9.norm() + a10.norm() + a11.norm() + a12.norm()) == 460.75, "TestFailed"
test1()
test2()
...@@ -39,6 +39,7 @@ except ImportError:

def is_zero_supported_optimizer(optimizer):
    if dist.get_rank() == 0:
        print(
            f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
        )
...
...@@ -119,7 +119,8 @@ RUN apt-get update && \

        libcupti-dev \
        libjpeg-dev \
        libpng-dev \
        screen \
        libaio-dev

RUN pip install psutil \
        yappi \
        cffi \
...
...@@ -232,14 +232,22 @@ Example of ***scheduler***

Enabling and configuring ZeRO memory optimizations
```json
  "zero_optimization": {
    "stage": [0|1|2|3],
    "allgather_partitions": [true|false],
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": [true|false],
    "reduce_bucket_size": 5e8,
    "contiguous_gradients" : [true|false],
    "cpu_offload": [true|false],
    "cpu_offload_params" : [true|false],
    "cpu_offload_use_pin_memory" : [true|false],
    "stage3_max_live_parameters" : 1e9,
    "stage3_max_reuse_distance" : 1e9,
    "stage3_prefetch_bucket_size" : 5e8,
    "stage3_param_persistence_threshold" : 1e6,
    "sub_group_size" : 1e12,
    "elastic_checkpoint" : [true|false]
  }
```
...@@ -253,7 +261,7 @@ Enabling and configuring ZeRO memory optimizations

| Description                                                                                                                                                                                                            | Default |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Chooses different stages of ZeRO Optimizer. Stage 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively. | `0` |
***allgather_partitions***: [boolean]

...@@ -297,6 +305,42 @@ Enabling and configuring ZeRO memory optimizations

| ------------------------------------------------------------------------------------------------------------------------ | ------- |
| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. | `False` |
***cpu_offload_params***: [boolean]
| Description | Default |
| --------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Enable offloading of model parameters to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3. | `False` |
***cpu_offload_use_pin_memory***: [boolean]
| Description | Default |
| ----------------------------------------------------------------------------------------- | ------- |
| Use pinned CPU memory when offloading. Can improve performance. Valid only with stage 3. | `False` |
***stage3_max_live_parameters***: [integer]
| Description | Default |
| ------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication. | `1e9` |
***stage3_max_reuse_distance***: [integer]
| Description | Default |
| ---------------------------------------------------------------------------------------------------------------- | ------- |
| Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication. | `1e9` |
***stage3_prefetch_bucket_size***: [integer]
| Description | Default |
| ------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The size of the fixed buffer for prefetching parameters. Smaller values use less memory, but can increase stalls due to communication. | `5e8` |
***stage3_param_persistence_threshold***: [integer]
| Description | Default |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` |
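To see how these stage-3 knobs fit together, here is an illustrative sketch (values are examples only, not tuned recommendations) that assembles a ZeRO-3 Offload configuration as a Python dict and hands it to `deepspeed.initialize`. Passing an in-memory dict via `config_params` is an assumption of this sketch; the equivalent JSON file supplied through `--deepspeed_config` works the same way.

```python
# Illustrative sketch only: values are examples, and the in-memory
# `config_params` path is an assumption; a JSON file passed with
# --deepspeed_config to the deepspeed launcher is the usual equivalent.
import torch
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 3,
        "cpu_offload": True,
        "cpu_offload_params": True,
        "cpu_offload_use_pin_memory": True,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6,
        "sub_group_size": 1e12,
    },
}

# Stand-in model; launch with the deepspeed launcher so distributed init is handled.
model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.Linear(4096, 1024))

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config_params=ds_config)
```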
### Logging

...
...@@ -3,7 +3,7 @@ title: "Zero Redundancy Optimizer (ZeRO)"
---

If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/tutorials/megatron/) before stepping through this tutorial.

In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed.
## ZeRO Overview
ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3).

...@@ -12,11 +12,13 @@ ZeRO leverages the aggregate computation and memory resources of data parallelis

* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.
* **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO will automatically collect and partition them during the forward and backward passes.
## Training environment
We use the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/tutorials/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM.

## Enabling ZeRO Optimization
To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed JSON configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training).

### Training a 1.5B Parameter GPT-2 model
We demonstrate the benefits of ZeRO stage 1 by showing that it enables data parallel training of a 1.5 billion parameter GPT-2 model on eight V100 GPUs. We configure training to use a batch size of 1 per device to ensure that the memory consumption is primarily due to model parameters and optimizer states. We create this training scenario by applying the following modifications to the deepspeed launch script:

...@@ -36,7 +38,7 @@ Training this model without ZeRO fails with an out-of-memory (OOM) error as show

<img src="/assets/images/oom_dp8_1.5B_log.png">
</a>

A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB; a significant portion of the 32GB RAM. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed JSON config file as below:

```json
{

...@@ -75,7 +77,7 @@ First, we need to configure a 10B parameter model with activation checkpointing
    --checkpoint-activations
```

Next, we need to update the DeepSpeed JSON configuration, as shown below, to enable ZeRO stage 2 optimizations:

```json
{

...@@ -104,4 +106,159 @@ Here is a screenshot of nvidia-smi showing GPU activity during training:

<img src="/assets/images/zero2_dp32_10B_smi.png">
</a>
### Training trillion-scale models with ZeRO-3 Offload
Stage 3 can be enabled in the JSON configuration. A full description of these
configurations is available [here](/docs/config-json/#zero-optimizations-for-fp16-training).
```json
{
"zero_optimization": {
"stage": 3,
"cpu_offload": true,
"cpu_offload_params": true,
"overlap_comm": true,
"contiguous_gradients": true,
"stage3_max_live_parameters": 6000000,
"stage3_max_reuse_distance": 100000000,
"stage3_prefetch_bucket_size": 200000,
"stage3_param_persitance_threshold": 100000,
"reduce_bucket_size": 3000000,
"sub_group_size": 1e6
}
}
```
ZeRO-3 will automatically collect and partition the parameters as they are
needed during the forward and backward passes. However, in some cases a
parameter may be used outside of its module's forward pass. We call these
*external parameters*. ZeRO-3 can coordinate these parameters if they are
registered. Please see our [ZeRO-3 docs](https://deepspeed.readthedocs.io/en/latest/zero3.html) for more
information and examples of external parameters.
The Megatron-LM model has three external parameters that must be registered
with ZeRO-3. External parameters are those that are accessed outside of the
owning module's forward pass.
1. `megatron/model/gpt2_model.py:GPT2Model`: register the word embedding weight because it is used both inside the embedding's forward pass and again outside of it when computing the output logits.
```python
class GPT2Model(MegatronModule):
def __init__(self, num_tokentypes=0, parallel_output=True):
...
deepspeed.zero.register_external_parameter(self,
self.language_model.embedding.word_embeddings.weight)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
# self.embeddings will compute its forward pass here
lm_output = self.language_model(input_ids,
position_ids,
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
...
# Accesses word_embeddings.weight outside of the embedding's forward pass.
output = parallel_lm_logits(
lm_output,
self.language_model.embedding.word_embeddings.weight,
parallel_output)
```
2. `megatron/model/transformer.py:ParallelMLP`: register a bias that is
returned from a submodule forward and used in this forward.
```python
class ParallelMLP(MegatronModule):
def __init__(self, init_method, output_layer_init_method):
...
if self.dense_h_to_4h.bias is not None:
deepspeed.zero.register_external_parameter(self, self.dense_h_to_4h.bias)
def forward(self, hidden_states):
# bias_parallel is a parameter of dense_h_to_4h
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
...
```
3. `megatron/model/transformer.py:ParallelTransformerLayer`: register two biases that
are returned from submodules and used in forward.
```python
class ParallelTransformerLayer(MegatronModule):
...
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
...
if self.attention.dense.bias is not None:
deepspeed.zero.register_external_parameter(self, self.attention.dense.bias)
if self.mlp.dense_4h_to_h.bias is not None:
deepspeed.zero.register_external_parameter(self, self.mlp.dense_4h_to_h.bias)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
...
# attention_bias is a parameter returned from attention
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
...
# mlp_bias is a parameter returned from mlp
mlp_output, mlp_bias = self.mlp(layernorm_output)
...
```
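All three registrations follow the same pattern: a parameter owned by a submodule is consumed outside that submodule's forward pass, so the consuming module registers it. Below is a minimal, self-contained toy sketch of that pattern; the `Embedder` and `TiedLM` names are invented for illustration and are not Megatron-LM code.

```python
# Toy sketch; Embedder/TiedLM are invented names, not Megatron-LM modules.
import torch
import deepspeed


class Embedder(torch.nn.Module):
    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_dim)

    def forward(self, tokens):
        return self.word_embeddings(tokens)


class TiedLM(torch.nn.Module):
    def __init__(self, vocab_size=1000, hidden_dim=64):
        super().__init__()
        self.embedder = Embedder(vocab_size, hidden_dim)
        # The embedding weight is also used in TiedLM.forward, outside of
        # Embedder.forward, so ZeRO-3 must be told about that extra access.
        deepspeed.zero.register_external_parameter(
            self, self.embedder.word_embeddings.weight)

    def forward(self, tokens):
        hidden = self.embedder(tokens)
        # Tied output projection reuses the embedding weight directly.
        return hidden @ self.embedder.word_embeddings.weight.t()
```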
#### Allocating Massive Megatron-LM Models
We make two further changes to model initialization in order to support models
that exceed *local* system memory, but not *total* system memory.
1. Allocate the model in a memory-scalable fashion. The model parameters will
be allocated and immediately partitioned across the data parallel group. If
`remote_device="cpu"`, the model will also be allocated in CPU memory
instead of GPU memory. Please see the full
[ZeRO-3 Init docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.Init)
for more details.
```python
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=get_args().remote_device,
enabled=get_args().zero_stage==3):
model = GPT2Model(num_tokentypes=0, parallel_output=True)
```
2. Gather the position embeddings weight for initialization. DeepSpeed will automatically
gather a module's parameters during its constructor and for its forward and backward pass.
However, additional accesses must coordinate with DeepSpeed to ensure that parameter data
is gathered and subsequently partitioned. If the tensor is modified, the `modifier_rank`
argument should also be used to ensure all ranks have a consistent view of
the data. Please see the full
[GatheredParameters docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.GatheredParameters)
for more details.
```python
self.position_embeddings = torch.nn.Embedding(...)
with deepspeed.zero.GatheredParameters(self.position_embeddings.weight,
modifier_rank=0):
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
```
Congratulations! You have completed the ZeRO tutorial.