Unverified Commit f2ac7eaf authored by Jeff Rasley, committed by GitHub

ZeRO-2 (#217)



Updates for ZeRO stage 2 + ZeRO stage 1 with reduce scatter (RS)
Co-authored-by: Tunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: Elton Zheng <eltonz@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: yuxionghe <yuxhe@microsoft.com>
Co-authored-by: Arash Ashari <arashari@microsoft.com>
parent c61e23b4
......@@ -2,6 +2,7 @@
.idea/
*~
*.swp
*.log
deepspeed/git_version_info.py
# Build + installation data
......
Subproject commit 9e2c735f5aabe48395c03a276fa7a0c51f6d3025
Subproject commit 274787a189b265814ed75dd5ddeae2dce026ea88
[![Build Status](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_apis/build/status/microsoft.DeepSpeed?branchName=master)](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master)
[![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
[DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization library that makes distributed training easy,
efficient, and effective.
[DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization
library that makes distributed training easy, efficient, and effective.
<p align="center"><i><b>10x Larger Models</b></i></p>
<p align="center"><i><b>5x Faster Training</b></i></p>
<p align="center"><i><b>10x Faster Training</b></i></p>
<p align="center"><i><b>Minimal Code Change</b></i></p>
DeepSpeed can train deep learning models with over a hundred billion parameters on the current
generation of GPU clusters, while achieving over 5x in system performance
generation of GPU clusters, while achieving over 10x in system performance
compared to the state of the art. Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
establishing a new SOTA in the LM category.
**_For further documentation, tutorials, and technical deep-dives please see [deepspeed.ai](https://www.deepspeed.ai/)!_**
# News
* [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
* [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
* [2020/05/19] [ZeRO-2 empowers training models as large as 170 billion parameters up to 10x faster compared to state-of-the-art](https://www.deepspeed.ai/news/2020/05/19/zero-stage2.html)
<span style="color:dodgerblue">**[_NEW_]**</span>
* [2020/05/19] [DeepSpeed optimizes transformer kernels to achieve world’s fastest BERT training record: 44 minutes on 1024 NVIDIA V100 GPUs](https://www.deepspeed.ai/news/2020/05/19/bert-record.html)
<span style="color:dodgerblue">**[_NEW_]**</span>
* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
# Table of Contents
......@@ -39,93 +46,6 @@ a large model easily runs out of memory with pure data parallelism and it is
difficult to use model parallelism. DeepSpeed addresses these challenges to
accelerate model development *and* training.
## Distributed, Effective, and Efficient Training with Ease
The DeepSpeed API is a lightweight wrapper on [PyTorch](https://pytorch.org/). This
means that you can use everything you love in PyTorch without having to learn a new
platform. In addition, DeepSpeed manages all of the boilerplate state-of-the-art
training techniques, such as distributed training, mixed precision, gradient
accumulation, and checkpointing, so that you can focus on your model development. Most
importantly, you can leverage the distinctive efficiency and effectiveness benefits of
DeepSpeed to boost speed and scale with just a few lines of code changes to your PyTorch
models.
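As a hedged sketch (the toy model, argparse flags, and config file name below are illustrative placeholders rather than part of this README), wrapping an existing PyTorch script looks like this:

```python
import argparse

import torch
import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--deepspeed_config', type=str, default='ds_config.json')
args = parser.parse_args()

# Toy stand-in for a real model; any torch.nn.Module works the same way.
model = torch.nn.Linear(1024, 1024)

# DeepSpeed builds the optimizer, data parallelism, fp16, etc. from the
# JSON config referenced by --deepspeed_config and returns a wrapped engine.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters())
```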
## Speed
DeepSpeed achieves high performance and fast convergence through a combination of
efficiency optimizations on compute/communication/memory/IO and effectiveness
optimizations on advanced hyperparameter tuning and optimizers. For example:
* DeepSpeed trains BERT-large to parity in 14 hours using 64 GPUs (4 DGX-2 boxes) and in
3.7 hours using 256 GPUs (16 DGX-2 boxes).
**BERT-large Training Times**
| Devices | Source | Training Time (hours) |
| ------------- | --------- | ---------------------:|
| 64 TPUs | Google | 96 |
| 64 V100 GPUs | DeepSpeed | **14** |
| 256 V100 GPUs | NVIDIA | 3.9 |
| 256 V100 GPUs | DeepSpeed | **3.7** |
*Read more*: [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/)
* DeepSpeed trains GPT2 (1.5 billion parameters) 3.75x faster than the state of the art (NVIDIA
Megatron) on Azure GPUs.
*Read more*: [GPT tutorial](https://www.deepspeed.ai/tutorials/megatron/)
## Memory efficiency
DeepSpeed provides memory-efficient data parallelism and enables training models without
model parallelism. For example, DeepSpeed can train models with up to 6 billion parameters on
NVIDIA V100 GPUs with 32GB of device memory. In comparison, existing frameworks (e.g.,
PyTorch's Distributed Data Parallel) run out of memory with 1.5 billion parameter models.
DeepSpeed reduces the training memory footprint through a novel solution called Zero
Redundancy Optimizer (ZeRO). Unlike basic data parallelism where memory states are
replicated across data-parallel processes, ZeRO partitions model states to save
significant memory. The current implementation (stage 1 of ZeRO) reduces memory by up to
4x relative to the state of the art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054).
With this impressive memory reduction, early adopters of DeepSpeed have already
produced a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
establishing a new SOTA in the LM category.
## Scalability
DeepSpeed supports efficient data parallelism, model parallelism, and their
combination. ZeRO boosts the scaling capability and efficiency further.
* DeepSpeed provides system support to run models up to 100 billion parameters,
10x larger than the state of the art (8-billion-parameter NVIDIA GPT, 11-billion-parameter Google T5).
* DeepSpeed can run large models more efficiently, up to 6x faster for models of
various sizes spanning 1.5B to 100B parameters. More specifically, the data parallelism powered by ZeRO
is complementary and can be combined with different types of model parallelism. It allows
DeepSpeed to fit models using a lower degree of model parallelism and a higher batch size, offering
significant performance gains compared to using model parallelism alone.
*Read more*: [technical report](https://arxiv.org/abs/1910.02054)
and [GPT tutorial](https://www.deepspeed.ai/tutorials/megatron/)
![DeepSpeed-vs-Megatron](./docs/assets/images/DeepSpeed-vs-Megatron.png)
<p align="center">
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of NVIDIA Megatron-LM) over using Megatron-LM alone.</em>
</p>
## Fast convergence for effectiveness
DeepSpeed supports advanced hyperparameter tuning and large batch size
optimizers such as [LAMB](https://arxiv.org/abs/1904.00962). These improve the
effectiveness of model training and reduce the number of samples required to
converge to the desired accuracy.
*Read more*: [Tuning tutorial](https://www.deepspeed.ai/tutorials/1Cycle/) and [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/)
## Usability
Only a few lines of code changes are needed to enable a PyTorch model to use DeepSpeed and ZeRO. Compared to current model parallelism libraries, DeepSpeed does not require a code redesign or model refactoring. It also does not put limitations on model dimensions (such as number of attention heads, hidden sizes, and others), batch size, or any other training parameters. For models of up to six billion parameters, you can use ZeRO-powered data parallelism conveniently without requiring model parallelism, while in contrast, standard data parallelism will run out of memory for models with more than 1.3 billion parameters. In addition, DeepSpeed conveniently supports flexible combination of ZeRO-powered data parallelism with custom model parallelisms, such as tensor slicing of NVIDIA's Megatron-LM.
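A hedged sketch of the corresponding training-loop change (`model_engine` comes from `deepspeed.initialize` as above; the toy data and loss function are placeholders):

```python
import torch

criterion = torch.nn.MSELoss()
data = [(torch.randn(8, 1024), torch.randn(8, 1024)) for _ in range(10)]

for batch, target in data:
    batch = batch.to(model_engine.local_rank)
    target = target.to(model_engine.local_rank)
    loss = criterion(model_engine(batch), target)
    model_engine.backward(loss)  # replaces loss.backward()
    model_engine.step()          # replaces optimizer.step()
```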
# Features
Below we provide a brief feature list, see our detailed [feature
......
......@@ -35,11 +35,6 @@ jobs:
pre-commit run --all-files
displayName: 'Formatting checks'
- script: |
pip install --user pylint
pylint --exit-zero deepspeed/
displayName: 'Code linter'
- script: |
pytest --forked --verbose tests/unit/
displayName: 'Unit tests'
......
......@@ -6,6 +6,8 @@ from deepspeed.pt.deepspeed_light import DeepSpeedLight
from deepspeed.pt.deepspeed_light import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from deepspeed.pt.deepspeed_lr_schedules import add_tuning_arguments
import deepspeed.pt.deepspeed_checkpointing as checkpointing
try:
from deepspeed.git_version_info import git_hash, git_branch
except ImportError:
......@@ -14,7 +16,7 @@ except ImportError:
# Export version information
__version_major__ = 0
__version_minor__ = 1
__version_minor__ = 2
__version_patch__ = 0
__version__ = '.'.join(
map(str,
......@@ -33,7 +35,8 @@ def initialize(args,
lr_scheduler=None,
mpu=None,
dist_init_required=None,
collate_fn=None):
collate_fn=None,
config_params=None):
"""Initialize the DeepSpeed Engine.
Arguments:
......@@ -91,7 +94,8 @@ def initialize(args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn)
collate_fn=collate_fn,
config_params=config_params)
return_items = [
engine,
......
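# Illustrative sketch (not part of the changed files): the new config_params
# argument passes the DeepSpeed config as a Python dict instead of pointing
# --deepspeed_config at a JSON file. The keys follow the constants added in
# this commit; the model and values are placeholders.
#
#   import argparse
#   import torch
#   import deepspeed
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--local_rank', type=int, default=0)
#   args = parser.parse_args()
#
#   ds_config = {
#       "train_batch_size": 8,
#       "fp16": {"enabled": True},
#       "zero_optimization": 2   # 0 disables ZeRO, 1 = optimizer states, 2 = + gradients
#   }
#
#   model = torch.nn.Linear(128, 128)
#   model_engine, optimizer, _, _ = deepspeed.initialize(
#       args=args,
#       model=model,
#       model_parameters=model.parameters(),
#       config_params=ds_config)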
'''
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
Use to partition the activations stored for backward propagation
Therefore reduces the memory consumption
Also implements CPU checkpointing and contiguous memory checkpointing
Reduces memory consumption and memory fragmentation
Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
'''
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import torch.distributed as dist
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from deepspeed.pt.deepspeed_timer import SynchronizedWallClockTimer as Timers
import torch.distributed as dist
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
#DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
#MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None
#Model Parameters
num_layers = None
#Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []
contiguous_size_buffers = []
size_offsets = []
timers = None
#optimization flags
PARTITION_ACTIVATIONS = False
PA_TO_CPU = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False
def see_memory_usage(message, force=False):
#return
if not force:
return
#dist.barrier()
if dist.get_rank() == 0:
print(message)
print("Memory Allocated ",
torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes")
print("Max Memory Allocated ",
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes")
print("Cache Allocated ",
torch.cuda.memory_cached() / (1024 * 1024 * 1024),
"GigaBytes")
print("Max cache Allocated ",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
"GigaBytes")
print(" ")
#input("Press Any Key To Continue ..")
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
transport_stream = None
cuda_device = None
def detach_variable(inputs, device=None):
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
continue
requires_grad = inp.requires_grad
if device is not None:
x = inp.to(device=device)
else:
x = inp
x = x.detach()
x.requires_grad = requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError(
"Only tuples of tensors are supported. Got unsupported input type: ",
type(inputs).__name__)
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues when using 4 or more GPUs.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for bookkeeping and to ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
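# Illustrative helper (not part of the original file): run an op such as
# dropout under the tracked model-parallel RNG state so its randomness differs
# across model-parallel ranks but matches across data-parallel replicas.
def _example_model_parallel_dropout(activations, p=0.1):
    with get_cuda_rng_tracker().fork():
        return torch.nn.functional.dropout(activations, p=p)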
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is a replacement for that
function.
Two sets of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model parallel groups. This is used for
example for dropout in the non-model-parallel regions.
model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
global mpu
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + mpu.get_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
if torch.distributed.get_rank() == 0:
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(),
mpu.get_model_parallel_rank(),
mpu.get_data_parallel_rank(),
model_parallel_seed,
data_parallel_seed),
flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
def get_partition_start(item):
global mp_rank, mp_size, mp_group
size = item.numel()
partition_size = size / mp_size
start = partition_size * mp_rank
return int(start)
def get_partition_size(item):
global mp_rank, mp_size, mp_group
size = item.numel()
assert size % mp_size == 0, "Cannot partition activations: item size is not divisible by mp_size"
partition_size = size / mp_size
return int(partition_size)
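# Worked example (illustrative): for an item with 1024 elements and mp_size = 4,
# each partition holds 1024 / 4 = 256 elements, so mp_rank 0 starts at element 0,
# mp_rank 1 at 256, mp_rank 2 at 512, and mp_rank 3 at 768.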
def get_full_inputs(tensors, device=None):
inputs = []
num_args = int(len(tensors) / 2)
for i in range(num_args - 1):
item = tensors[2 * i]
size = tensors[2 * i + 1]
partition_size = item.numel()
tensor_size = partition_size * mp_size
if device is not None:
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
else:
flat_tensor = torch.zeros([tensor_size],
dtype=item.dtype,
device=item.device)
partitions = []
for i in range(mp_size):
part_i = flat_tensor.narrow(0, partition_size * i, partition_size)
if i == mp_rank:
part_i.copy_(item)
partitions.append(part_i)
if mp_group is not None:
dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
input_tensor = flat_tensor.view(list(size.numpy()))
item.data = input_tensor.data
inputs.append(item)
inputs.append(tensors[-2])
return tuple(inputs)
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
the following changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
3) Performant activation partitioning and contiguous memory optimization
4) CPU checkpointing
5) Profiling of the forward and backward functions
"""
@staticmethod
def forward(ctx, run_function, *args):
global mpu, timers, SYNCHRONIZE, PROFILE_TIME
if SYNCHRONIZE:
torch.cuda.synchronize()
if timers is None and PROFILE_TIME:
timers = Timers()
if PROFILE_TIME:
timers('forward').start()
ctx.run_function = run_function
global num_layers
global mp_rank, mp_size, mp_group
global contiguous_data_buffers, contiguous_size_buffers
global data_offsets, size_offsets
if mp_rank is None:
if mpu is not None:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
mp_group = None
global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
if cuda_device is None:
see_memory_usage("First Forward Beginning", force=True)
if dist.get_rank() == 0:
print(f"Activation Checkpointing Information")
print(
f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
)
print(
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
)
print(f"----Synchronization {SYNCHRONIZE}")
print(f"----Profiling {PROFILE_TIME}")
cuda_device = torch.cuda.current_device()
transport_stream = torch.cuda.Stream(device=cuda_device)
if PARTITION_ACTIVATIONS:
#inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
#inputs.append(args[-1])
inputs = []
for i, item in enumerate(args[:-1]):
partition_size = get_partition_size(item)
partition = item.detach().contiguous().view(-1).narrow(
0,
get_partition_start(item),
partition_size).clone()
if CONTIGUOUS_CHECKPOINTING:
buffer_device = torch.device(
'cpu') if PA_TO_CPU else partition.device
if i >= len(contiguous_data_buffers):
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers.append(tensor_list)
data_offsets.append(0)
elif contiguous_data_buffers[i] is None:
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0
contiguous_partition = contiguous_data_buffers[i][
data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
inputs.append(contiguous_partition)
else:
partition = partition.cpu() if PA_TO_CPU else partition
inputs.append(partition)
inputs.append(args[-1])
#just in case something funky is happening such as reuse of inputs
inputs_cuda = [item.to(cuda_device) for item in args]
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
#ctx.save_for_backward(*args)
with torch.no_grad():
outputs = run_function(*inputs_cuda)
del inputs_cuda
#with torch.cuda.stream(transport_stream):
#if PARTITION_ACTIVATIONS:
# new_args = []
# for arg, inp in zip(args,inputs):
# size= torch.tensor(arg.size())
# arg.data = inp.data
# new_args.append(arg)
# new_args.append(size)
# ctx.save_for_backward(*new_args)
if PARTITION_ACTIVATIONS:
new_args = []
for i, (arg, inp) in enumerate(zip(args, inputs)):
size = torch.tensor(arg.size())
arg.data = inp.data
new_args.append(arg)
if CONTIGUOUS_CHECKPOINTING:
numel = size.numel()
if i >= len(contiguous_size_buffers):
tmp = torch.tensor(())
contiguous_size_buffers.append(
tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device))
size_offsets.append(0)
elif contiguous_size_buffers[i] is None:
tmp = torch.tensor(())
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device)
size_offsets[i] = 0
contiguous_size = contiguous_size_buffers[i].narrow(
0,
size_offsets[i],
numel).data.copy_(size.data)
contiguous_size = contiguous_size.view_as(size)
size_offsets[i] = size_offsets[i] + numel
new_args.append(contiguous_size)
else:
new_args.append(size)
#if dist.get_rank() == 0:
# print (f"The stored tensor is {contiguous_size} and original one is {size} ")
ctx.save_for_backward(*new_args)
else:
ctx.save_for_backward(*args)
if PROFILE_TIME:
timers('forward').stop()
timers.log(['forward'])
if SYNCHRONIZE:
torch.cuda.synchronize()
return outputs
@staticmethod
def backward(ctx, *args):
global timers
#see_memory_usage("In backward", force=True)
#removing pointers to the contiguous buffer memory
#so that they can be garbage collected once the checkpoints
#have been used
if SYNCHRONIZE:
torch.cuda.synchronize()
if PROFILE_TIME:
timers('backward').start()
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers
for buffers in contiguous_data_buffers:
buffers = []
#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
#see_memory_usage("In backward checkpointing code", force=True)
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
if PARTITION_ACTIVATIONS:
#with torch.cuda.stream(transport_stream):
inputs = get_full_inputs(ctx.saved_tensors,
device=cuda_device if PA_TO_CPU else None)
detached_inputs = detach_variable(inputs)
else:
inputs = ctx.saved_tensors
detached_inputs = detach_variable(inputs)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# if PARTITION_ACTIVATIONS:
# current_stream=torch.cuda.current_stream()
# current_stream.wait_stream(transport_stream)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs, )
torch.autograd.backward(outputs, args)
if PROFILE_TIME:
timers('backward').stop()
timers.log(['backward'])
if SYNCHRONIZE:
torch.cuda.synchronize()
return (None, ) + tuple(inp.grad for inp in detached_inputs)
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint. """
return CheckpointFunction.apply(function, *args)
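# Illustrative helper (not part of the original file): checkpoint a stack of
# layers so that only their inputs, rather than all intermediate activations,
# are kept for the backward pass. `layers` and `hidden_states` are placeholders.
def _example_checkpoint_layers(layers, hidden_states):
    for layer in layers:
        hidden_states = checkpoint(layer, hidden_states)
    return hidden_states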
def partition_activations_in_checkpoint(partition_activation):
global PARTITION_ACTIVATIONS
PARTITION_ACTIVATIONS = partition_activation
if dist.get_rank() == 0:
print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
def set_num_layers(nlayers):
global num_layers
num_layers = nlayers
def reset():
"""Resets memory buffers related to contiguous memory optimizations.
Should be called during eval when multiple forward propagations are
computed without any backward propagation, which would normally clear these
buffers.
Arguments:
None
Return:
None
"""
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers
for buffers in contiguous_data_buffers:
buffers = []
#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
def _configure_using_config_file(deepspeed_config):
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config
print(config.repr())
PARTITION_ACTIVATIONS = config.partition_activations
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
num_layers = config.number_checkpoints
PA_TO_CPU = config.cpu_checkpointing
SYNCHRONIZE = config.synchronize_checkpoint_boundary
PROFILE_TIME = config.profile
def _configure_defaults():
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
PARTITION_ACTIVATIONS = False
CONTIGUOUS_CHECKPOINTING = False
num_layers = False
PA_TO_CPU = False
SYNCHRONIZE = False
PROFILE_TIME = False
deepspeed_checkpointing_enabled = True
def configure(
mpu_,
deepspeed_config=None,
partition_activations=None,
contiguous_checkpointing=None,
num_checkpoints=None,
checkpoint_in_cpu=None,
synchronize=None,
profile=None,
):
"""Configure DeepSpeed Activation Checkpointing.
Arguments:
mpu_: Optional: An object that implements the following methods
get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size
deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to
configure DeepSpeed Activation Checkpointing
partition_activations: Optional: Partitions activation checkpoint across model parallel
GPUs when enabled. By default False. Will overwrite deepspeed_config if provided
contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
provided
num_checkpoints: Optional: Number of activation checkpoints stored during the forward
propagation of the model. Used to calculate the buffer size for contiguous_checkpointing
Will overwrite deepspeed_config if provided
checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
partition_activation. Default is false. Will overwrite deepspeed_config if provided
synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of
each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
By default false. Will overwrite deepspeed_config if provided
profile: Optional: Logs the forward and backward time for each
deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
if provided
Returns:
None
"""
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
_configure_defaults()
if deepspeed_config is not None:
_configure_using_config_file(deepspeed_config)
if mpu_ is not None:
mpu = mpu_
if partition_activations is not None:
PARTITION_ACTIVATIONS = partition_activations
if contiguous_checkpointing is not None:
CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing
if num_checkpoints is not None:
num_layers = num_checkpoints
if checkpoint_in_cpu is not None:
PA_TO_CPU = checkpoint_in_cpu
if synchronize is not None:
SYNCHRONIZE = synchronize
if profile is not None:
PROFILE_TIME = profile
if PA_TO_CPU or CONTIGUOUS_CHECKPOINTING:
assert PARTITION_ACTIVATIONS, "CPU Checkpointing/Contiguous Checkpointing is only available with partitioned activations. Set partition_activations to true in the deepspeed config"
if CONTIGUOUS_CHECKPOINTING:
assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"
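# Illustrative helper (not part of the original file): programmatically enable
# partitioned, contiguous, CPU-offloaded activation checkpointing. The values
# are arbitrary, and `mpu_` is assumed to be a Megatron-style model-parallel
# utility object exposing the methods listed in the docstring above.
def _example_configure_activation_checkpointing(mpu_):
    configure(mpu_,
              partition_activations=True,
              contiguous_checkpointing=True,
              num_checkpoints=24,
              checkpoint_in_cpu=True,
              profile=True)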
def is_configured():
"""True if deepspeed activation checkpointing has been configured
by calling deepspeed.checkpointing.configure, else returns false
Arguments:
None
Return:
True if configured, else False
"""
global deepspeed_checkpointing_enabled
return deepspeed_checkpointing_enabled
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
#########################################
# DeepSpeed Activation Checkpointing
#########################################
# Activation checkpointing saves memory by keeping only a select few
# activations for backpropagation.
ACTIVATION_CHKPT_FORMAT = '''
Activation Checkpointing should be configured as:
"session_params": {
"activation_checkpointing": {
"partition_activations": [true|false],
"number_checkpoints": 100,
"contiguous_memory_optimization": [true|false],
"cpu_checkpointing": [true|false],
"profile": [true|false],
"synchronize_checkpoint_boundary": [true|false]
}
}
'''
ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations'
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False
ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints'
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization'
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary'
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False
ACT_CHKPT_PROFILE = 'profile'
ACT_CHKPT_PROFILE_DEFAULT = False
ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing'
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False
ACT_CHKPT = 'activation_checkpointing'
ACT_CHKPT_DEFAULT = {
ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT,
ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION:
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY:
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT,
ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT,
ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT
}
class DeepSpeedActivationCheckpointingConfig(object):
def __init__(self, param_dict):
super(DeepSpeedActivationCheckpointingConfig, self).__init__()
self.partition_activations = None
self.contiguous_memory_optimization = None
self.cpu_checkpointing = None
self.number_checkpoints = None
self.synchronize_checkpoint_boundary = None
self.profile = None
if ACT_CHKPT in param_dict.keys():
act_chkpt_config_dict = param_dict[ACT_CHKPT]
else:
act_chkpt_config_dict = ACT_CHKPT_DEFAULT
self._initialize(act_chkpt_config_dict)
"""
For json serialization
"""
def repr(self):
return self.__dict__
def _initialize(self, act_chkpt_config_dict):
self.partition_activations = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_PARTITION_ACTIVATIONS,
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT)
self.contiguous_memory_optimization = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT)
self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_CPU_CHECKPOINTING,
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT)
self.number_checkpoints = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_NUMBER_CHECKPOINTS,
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT)
self.profile = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_PROFILE,
ACT_CHKPT_PROFILE_DEFAULT)
self.synchronize_checkpoint_boundary = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT)
......@@ -8,6 +8,9 @@ import logging
import json
from deepspeed.pt.deepspeed_constants import *
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.pt.deepspeed_zero_config import DeepSpeedZeroConfig
from deepspeed.pt.deepspeed_checkpointing_config import DeepSpeedActivationCheckpointingConfig
TENSOR_CORE_ALIGN_SIZE = 8
ADAM_OPTIMIZER = 'adam'
......@@ -15,13 +18,6 @@ LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
def get_scalar_param(param_dict, param_name, param_default_value):
if param_name in param_dict.keys():
return param_dict[param_name]
else:
return param_default_value
def get_fp16_enabled(param_dict):
if FP16 in param_dict.keys():
return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
......@@ -92,10 +88,20 @@ def get_sparse_gradients_enabled(param_dict):
return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)
def get_zero_enabled(param_dict):
def get_zero_optimization(param_dict):
return get_scalar_param(param_dict, ZERO_OPTIMIZATION, ZERO_OPTIMIZATION_DEFAULT)
def get_zero_reduce_scatter(param_dict):
return get_scalar_param(param_dict, ZERO_REDUCE_SCATTER, ZERO_REDUCE_SCATTER_DEFAULT)
def get_zero_max_elements_per_comm(param_dict):
return get_scalar_param(param_dict,
ZERO_MAX_ELEMENTS_PER_COMM,
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT)
def get_allgather_size(param_dict):
return get_scalar_param(param_dict,
ALLGATHER_SIZE,
......@@ -204,6 +210,10 @@ def get_wall_clock_breakdown(param_dict):
WALL_CLOCK_BREAKDOWN_DEFAULT)
def get_memory_breakdown(param_dict):
return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)
def get_tensorboard_enabled(param_dict):
if TENSORBOARD in param_dict.keys():
return get_scalar_param(param_dict[TENSORBOARD],
......@@ -231,10 +241,39 @@ def get_tensorboard_job_name(param_dict):
return TENSORBOARD_JOB_NAME_DEFAULT
'''Write deepspeed config files by modifying basic templates.
Can be used for quickly changing parameters via command-line arguments.'''
class DeepSpeedConfigWriter:
def __init__(self, data=None):
self.data = data if data is not None else {}
def add_config(self, key, value):
self.data[key] = value
def load_config(self, filename):
self.data = json.load(open(filename,
'r'),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
def write_config(self, filename):
with open(filename, 'w') as outfile:
json.dump(self.data, outfile)
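# Illustrative sketch (not part of the original file): derive a run-specific
# config from a base template; the file names and batch size are hypothetical.
def _example_write_config(base_template='ds_config_template.json',
                          out_file='ds_config_run.json'):
    writer = DeepSpeedConfigWriter()
    writer.load_config(base_template)
    writer.add_config('train_batch_size', 64)
    writer.write_config(out_file)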
class DeepSpeedConfig(object):
def __init__(self, json_file, mpu=None):
def __init__(self, json_file, mpu=None, param_dict=None):
super(DeepSpeedConfig, self).__init__()
self._param_dict = json.load(open(json_file, 'r'))
if param_dict is None:
self._param_dict = json.load(
open(json_file,
'r'),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
else:
self._param_dict = param_dict
try:
self.global_rank = torch.distributed.get_rank()
if mpu is None:
......@@ -263,7 +302,14 @@ class DeepSpeedConfig(object):
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
self.allgather_size = get_allgather_size(param_dict)
self.zero_enabled = get_zero_enabled(param_dict)
self.zero_config = DeepSpeedZeroConfig(param_dict)
self.zero_optimization_stage = self.zero_config.stage
self.zero_enabled = self.zero_optimization_stage > 0
self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(
param_dict)
self.gradient_clipping = get_gradient_clipping(param_dict)
self.fp16_enabled = get_fp16_enabled(param_dict)
self.loss_scale = get_loss_scale(param_dict)
......@@ -285,6 +331,7 @@ class DeepSpeedConfig(object):
self.scheduler_params = get_scheduler_params(param_dict)
self.wall_clock_breakdown = get_wall_clock_breakdown(param_dict)
self.memory_breakdown = get_memory_breakdown(param_dict)
self.tensorboard_enabled = get_tensorboard_enabled(param_dict)
self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
......@@ -305,8 +352,8 @@ class DeepSpeedConfig(object):
f'Gradient accumulation steps: {grad_acc} has to be greater than 0'
assert train_batch == micro_batch * grad_acc * self.world_size, \
(f'Check batch related parameters. Train_batch_size is not equal'
'to micro_batch_per_gpu * gradient_acc_step * world_size'
(f'Check batch related parameters. train_batch_size is not equal'
' to micro_batch_per_gpu * gradient_acc_step * world_size'
f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
def _set_batch_related_parameters(self):
......@@ -387,6 +434,7 @@ class DeepSpeedConfig(object):
def _do_error_check(self):
if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
......
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
"""
Collection of DeepSpeed configuration utilities
"""
def get_scalar_param(param_dict, param_name, param_default_value):
if param_name in param_dict.keys():
return param_dict[param_name]
else:
return param_default_value
def dict_raise_error_on_duplicate_keys(ordered_pairs):
"""Reject duplicate keys."""
d = {}
for k, v in ordered_pairs:
if k in d:
raise ValueError("Duplicate key in DeepSpeed config: %r" % (k, ))
else:
d[k] = v
return d
......@@ -15,7 +15,7 @@ ROUTE_ENCODE = "encode"
# Batch size
#############################################
TRAIN_BATCH_SIZE = "train_batch_size"
TRAIN_BATCH_SIZE_DEFAULT = 1
TRAIN_BATCH_SIZE_DEFAULT = None
#############################################
# Optimizer and lr scheduler
......@@ -133,14 +133,27 @@ GRADIENT_CLIPPING_DEFAULT = 0.
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users can configure in ds_config.json as below example:
# Users have to configure the desired optimization stage (0 means disabled) in params.json, as in the example below:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"zero_optimization": true,
"zero_all_gather_size": 200
"session_params": {
"zero_optimization": [0|1|2],
"zero_all_gather_size": 200
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DEFAULT = False
ZERO_OPTIMIZATION_DEFAULT = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_REDUCE_SCATTER = "zero_reduce_scatter"
ZERO_REDUCE_SCATTER_DEFAULT = True
ZERO_MAX_ELEMENTS_PER_COMM = "zero_max_elements_per_comm"
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT = 5e8
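# Illustrative config fragment (hypothetical values) combining the ZeRO keys
# defined above; fp16 must be enabled whenever ZeRO is enabled:
#
#   {
#     "train_batch_size": 32,
#     "fp16": {"enabled": true},
#     "zero_optimization": 2,
#     "zero_reduce_scatter": true,
#     "zero_max_elements_per_comm": 500000000
#   }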
ALLGATHER_SIZE = 'allgather_size'
ALLGATHER_SIZE_DEFAULT = 500000000
......@@ -217,6 +230,9 @@ Wall block breakdown should be enabled as:
WALL_CLOCK_BREAKDOWN = 'wall_clock_breakdown'
WALL_CLOCK_BREAKDOWN_DEFAULT = False
MEMORY_BREAKDOWN = 'memory_breakdown'
MEMORY_BREAKDOWN_DEFAULT = False
#########################################
# Tensorboard
#########################################
......
......@@ -8,11 +8,14 @@ import os
import warnings
import torch.distributed as dist
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter
from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
......@@ -21,8 +24,10 @@ from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import ROUTE_TRAIN, ROUTE_PREDICT, \
ROUTE_EVAL, TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.pt.deepspeed_constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT, \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
......@@ -96,7 +101,8 @@ class DeepSpeedLight(Module):
lr_scheduler=None,
mpu=None,
dist_init_required=None,
collate_fn=None):
collate_fn=None,
config_params=None):
super(DeepSpeedLight, self).__init__()
logging.basicConfig(level=logging.INFO,
......@@ -116,6 +122,7 @@ class DeepSpeedLight(Module):
self.gradient_predivide_factor = 1.0
self.gradient_average = True
self.warn_unscaled_loss = True
self.config_params = config_params
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
......@@ -146,6 +153,9 @@ class DeepSpeedLight(Module):
# Configure distributed model
self._configure_distributed_model(model)
# Configure wall clock timer
self.timers = SynchronizedWallClockTimer()
# Throughput timer
self.tput_timer = ThroughputTimer(
batch_size=self.train_micro_batch_size_per_gpu(),
......@@ -163,9 +173,6 @@ class DeepSpeedLight(Module):
self._configure_lr_scheduler(lr_scheduler)
self._report_progress(0)
# Configure wall clock timer
self.timers = SynchronizedWallClockTimer()
# Bookkeeping for csr support
self.csr_tensor_module_names = set()
if self.sparse_gradients_enabled():
......@@ -245,6 +252,9 @@ class DeepSpeedLight(Module):
def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown
def memory_breakdown(self):
return self._config.memory_breakdown
def sparse_gradients_enabled(self):
return self._config.sparse_gradients_enabled
......@@ -275,6 +285,30 @@ class DeepSpeedLight(Module):
def zero_allow_untested_optimizer(self):
return self._config.zero_allow_untested_optimizer
def zero_reduce_scatter(self):
return self._config.zero_config.reduce_scatter
def zero_overlap_comm(self):
return self._config.zero_config.overlap_comm
def zero_max_elements_per_comm(self):
return self._config.zero_max_elements_per_comm
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size
def zero_allgather_bucket_size(self):
return self._config.zero_config.allgather_bucket_size
def zero_optimization_partition_gradients(self):
return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS
def zero_contiguous_gradients(self):
return self._config.zero_config.contiguous_gradients
def allgather_size(self):
return self._config.allgather_size
......@@ -296,8 +330,8 @@ class DeepSpeedLight(Module):
def steps_per_print(self):
return self._config.steps_per_print
def disable_allgather(self):
return self._config.disable_allgather
def zero_allgather_partitions(self):
return self._config.zero_config.allgather_partitions
def dump_state(self):
return self._config.dump_state
......@@ -375,7 +409,9 @@ class DeepSpeedLight(Module):
# Configure based on command line arguments
def _configure_with_arguments(self, args, mpu):
self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
self._config = DeepSpeedConfig(args.deepspeed_config, mpu)
self._config = DeepSpeedConfig(args.deepspeed_config,
mpu,
param_dict=self.config_params)
# Validate command line arguments
def _do_args_sanity_check(self, args):
......@@ -390,11 +426,12 @@ class DeepSpeedLight(Module):
assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
'DeepSpeed requires integer command line parameter --local_rank'
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
'DeepSpeed requires --deepspeed_config to specify configuration file'
if self.config_params is None:
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
'DeepSpeed requires --deepspeed_config to specify configuration file'
assert os.path.isfile(args.deepspeed_config), \
'DeepSpeed configuration file: {} is not an existing file'.format(args.deepspeed_config)
assert os.path.isfile(args.deepspeed_config), \
'DeepSpeed configuration file: {} is not an existing file'.format(args.deepspeed_config)
def _is_supported_optimizer(self, optimizer_name):
return optimizer_name in DEEPSPEED_OPTIMIZERS or \
......@@ -424,7 +461,8 @@ class DeepSpeedLight(Module):
else:
self.data_parallel_group = self.mpu.get_data_parallel_group()
self.dp_world_size = self.mpu.get_data_parallel_world_size()
src_rank = self.mpu.get_model_parallel_rank()
src_rank = _get_global_rank(self.mpu.get_data_parallel_group(), 0)
print(f"global src_rank={src_rank}")
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
......@@ -518,17 +556,42 @@ class DeepSpeedLight(Module):
return optimizer
def _configure_zero_optimizer(self, optimizer):
logging.info('Creating fp16 zero optimizer')
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
dp_process_group=self.data_parallel_group,
clip_grad=self.gradient_clipping(),
all_gather_partitions=not self.disable_allgather(),
allgather_size=self.allgather_size(),
mpu=self.mpu)
zero_stage = self.zero_optimization_stage()
logging.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
logging.info('Creating fp16 ZeRO Optimizer Stage 1')
optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
optimizer,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
all_gather_partitions=self.zero_allgather_partitions(),
allgather_size=self.zero_allgather_bucket_size(),
max_elements_per_comm=self.zero_reduce_bucket_size(),
dp_process_group=self.data_parallel_group,
mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
assert self.gradient_accumulation_steps() == 1, "ZeRO stage 2 does not support gradient accumulation; if you need gradient accumulation, please use stage 1"
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
timers=self.timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
mpu=self.mpu)
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
logging.info('Creating fp16 zero stage {} optimizer'.format(zero_stage))
return optimizer
......@@ -624,7 +687,16 @@ class DeepSpeedLight(Module):
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
if self.is_gradient_accumulation_boundary():
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter()
self.optimizer.reduce_scatter_gradients(
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor,
gradient_average=self.gradient_average)
elif self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
def backward(self, loss, allreduce_gradients=True):
r"""Execute backward pass on the loss
......@@ -636,7 +708,7 @@ class DeepSpeedLight(Module):
# scale loss w.r.t. gradient accumulation if needed
if self.gradient_accumulation_steps() > 1:
loss = self._scale_loss(loss)
loss = self._scale_loss(loss.float())
# Log training Loss
if self.tensorboard_enabled():
......@@ -765,27 +837,28 @@ class DeepSpeedLight(Module):
'backward_inner_microstep',
'backward_allreduce_microstep',
'step_microstep'
])
# Log timing
if self.tensorboard_enabled():
if self.is_gradient_accumulation_boundary():
if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
]
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
self.summary_writer.flush()
self.timers.log([
'forward',
'backward',
'backward_inner',
'backward_allreduce',
'step'
])
],
memory_breakdown=self.memory_breakdown())
if self.is_gradient_accumulation_boundary():
if self.tensorboard_enabled() and torch.distributed.get_rank(
) == 0: # this is done before the log because log resets timers
self.summary_events = [(f'Train/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
]
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
self.summary_writer.flush()
self.timers.log([
'forward',
'backward',
'backward_inner',
'backward_allreduce',
'step'
])
self.micro_steps += 1
......@@ -971,19 +1044,30 @@ class DeepSpeedLight(Module):
if not os.path.exists(dirname):
os.makedirs(dirname)
def load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def load_checkpoint(self,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):
r"""Load training checkpoint
Arguments:
load_dir: Required. Directory to load the checkpoint from
tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the checkpoint.
Return:
load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
client_state: State dictionary used for loading required training states in the client code.
"""
load_path, client_states = self._load_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
load_path, client_states = self._load_checkpoint(load_dir,
tag,
load_module_strict=load_module_strict,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
if self.zero_optimization() and load_path is not None:
self._load_zero_checkpoint(load_dir,
......@@ -992,7 +1076,12 @@ class DeepSpeedLight(Module):
return load_path, client_states
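# Illustrative usage sketch (not part of the original file); the directory,
# tag, and flags below are hypothetical:
#
#   load_path, client_state = model_engine.load_checkpoint(
#       'checkpoints/my_run',
#       tag='global_step1000',
#       load_module_strict=False,
#       load_optimizer_states=True,
#       load_lr_scheduler_states=False)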
def _load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def _load_checkpoint(self,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):
load_path = self._get_ckpt_name(load_dir, tag)
......@@ -1005,12 +1094,13 @@ class DeepSpeedLight(Module):
logging.info('Loading checkpoint: {}'.format(load_path))
checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
self.load_module_state_dict(checkpoint['module'])
self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
self.optimizer.load_state_dict(checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)
if self.lr_scheduler is not None:
if load_lr_scheduler_states and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
......@@ -1019,6 +1109,7 @@ class DeepSpeedLight(Module):
deepspeed_states = [
'module',
'optimizer',
'lr_scheduler',
'csr_tensor_module_names',
'skipped_steps',
'global_steps'
......@@ -1058,19 +1149,15 @@ class DeepSpeedLight(Module):
#There seems to be an issue creating them in parallel
self._create_checkpoint_files(save_dir, tag)
try:
if self.save_non_zero_checkpoint:
self._save_checkpoint(save_dir, tag, client_state=client_state)
if self.save_non_zero_checkpoint:
self._save_checkpoint(save_dir, tag, client_state=client_state)
if self.save_zero_checkpoint:
self._save_zero_checkpoint(save_dir, tag)
if self.save_zero_checkpoint:
self._save_zero_checkpoint(save_dir, tag)
except:
logging.error(f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
return False
return True
def _create_checkpoint_files(self, save_dir, tag):
#checkpoint files are created sequentially
for rank in range(self.world_size):
if rank == self.global_rank:
......@@ -1114,14 +1201,8 @@ class DeepSpeedLight(Module):
torch.save(state, save_path)
def _save_zero_checkpoint(self, save_path, tag):
try:
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
#self._ensure_directory_exists(zero_checkpoint_name)
except:
logging.error(
f'Failed Saving Zero model checkpoint to {save_path} with tag {tag}')
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
#self._ensure_directory_exists(zero_checkpoint_name)
zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
torch.save(zero_sd, zero_checkpoint_name)
logging.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
......@@ -69,13 +69,27 @@ class SynchronizedWallClockTimer:
self.timers[name] = self.Timer(name)
return self.timers[name]
def log(self, names, normalizer=1.0, reset=True):
@staticmethod
def memory_usage():
alloc = "mem_allocated: {:.4f} GB".format(torch.cuda.memory_allocated() /
(1024 * 1024 * 1024))
max_alloc = "max_mem_allocated: {:.4f} GB".format(
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024))
cache = "cache_allocated: {:.4f} GB".format(torch.cuda.memory_cached() /
(1024 * 1024 * 1024))
max_cache = "max_cache_allocated: {:.4f} GB".format(
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))
return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)
def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if memory_breakdown:
string += self.memory_usage()
print_rank_0(string)
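# Hedged usage sketch, exposition only (not part of this change). `timers` is
# assumed to be a SynchronizedWallClockTimer instance that is callable by name,
# as it is used elsewhere in the engine; the helper name is hypothetical.
def _example_log_step_timing(timers):
    timers('backward').start()
    # ... run the backward pass here ...
    timers('backward').stop()
    # memory_breakdown=True appends the memory_usage() string to the timing line
    timers.log(['backward'], memory_breakdown=True)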
......
......@@ -12,9 +12,10 @@ from torch._six import inf
class CheckOverflow(object):
'''Checks for overflow in gradient across parallel process'''
def __init__(self, param_groups=None, mpu=None):
def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False):
self.mpu = mpu
self.params = [] if param_groups else None
self.zero_reduce_scatter = zero_reduce_scatter
if param_groups:
for group in param_groups:
for param in group:
......@@ -54,8 +55,8 @@ class CheckOverflow(object):
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
for i, p in enumerate(params):
if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
return True
return False
......@@ -67,7 +68,11 @@ class CheckOverflow(object):
#torch.distributed.all_reduce(overflow_gpu,
# op=torch.distributed.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
if self.mpu is not None:
if self.zero_reduce_scatter:
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=torch.distributed.group.WORLD)
elif self.mpu is not None:
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.mpu.get_model_parallel_group())
......@@ -76,7 +81,7 @@ class CheckOverflow(object):
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x):
def _has_inf_or_nan(x, i):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
......@@ -93,10 +98,25 @@ class CheckOverflow(object):
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
_handle_overflow(cpu_sum, x, i)
return True
return False
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
print(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
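# Hedged illustration, exposition only (not part of this change): how the
# per-tensor scan above is typically driven. The helper name is hypothetical,
# and _handle_overflow assumes torch.distributed has been initialized.
def _example_any_grad_overflow(grads):
    for i, g in enumerate(grads):
        if CheckOverflow._has_inf_or_nan(g, i):
            return True
    return False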
def get_grad_norm(parameters, norm_type=2, mpu=None):
"""Clips gradient norm of an iterable of parameters.
......@@ -221,3 +241,33 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
total_norm = -1
return total_norm
def is_model_parallel_parameter(p):
return hasattr(p, 'model_parallel') and p.model_parallel
def see_memory_usage(message):
# NOTE: reporting is currently disabled; the early return below makes this a no-op
return
if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
return
# Print message except when distributed but not rank 0
print(message, flush=True)
print("Memory Allocated ",
torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print("Max Memory Allocated ",
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print("Cache Allocated ",
torch.cuda.memory_cached() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print("Max cache Allocated ",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print(" ", flush=True)
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
import logging
#from deepspeed.pt.deepspeed_constants import *
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as in the example below:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT
}
class DeepSpeedZeroConfig(object):
def __init__(self, param_dict):
super(DeepSpeedZeroConfig, self).__init__()
self.stage = None
self.contiguous_gradients = None
self.reduce_scatter = None
self.reduce_bucket_size = None
self.allgather_partitions = None
self.allgather_bucket_size = None
self.overlap_comm = None
if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
if type(zero_config_dict) is bool:
zero_config_dict = self.read_zero_config_deprecated(param_dict)
else:
zero_config_dict = ZERO_OPTIMIZATION_DEFAULT
self._initialize(zero_config_dict)
def read_zero_config_deprecated(self, param_dict):
zero_config_dict = {}
zero_config_dict[
ZERO_OPTIMIZATION_STAGE] = 1 if param_dict[ZERO_OPTIMIZATION] else 0
if zero_config_dict[ZERO_OPTIMIZATION_STAGE] > 0:
zero_config_dict[ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE] = get_scalar_param(
param_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
logging.warning(
'DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}'
.format(ZERO_FORMAT))
return zero_config_dict
"""
For json serialization
"""
def repr(self):
return self.__dict__
def _initialize(self, zero_config_dict):
self.stage = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_STAGE,
ZERO_OPTIMIZATION_STAGE_DEFAULT)
self.contiguous_gradients = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
self.reduce_bucket_size = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT)
self.reduce_scatter = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_REDUCE_SCATTER,
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)
self.overlap_comm = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_OVERLAP_COMM,
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)
self.allgather_partitions = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT)
self.allgather_bucket_size = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
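# Hedged usage sketch, exposition only (not part of this change): how a user
# config dict is resolved against the defaults above. The dict literal and the
# helper name are hypothetical.
def _example_zero_stage2_config():
    param_dict = {
        ZERO_OPTIMIZATION: {
            ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_GRADIENTS,  # stage 2
            ZERO_OPTIMIZATION_OVERLAP_COMM: True
        }
    }
    zero_config = DeepSpeedZeroConfig(param_dict)
    # keys that are not supplied fall back to their *_DEFAULT values
    assert zero_config.stage == ZERO_OPTIMIZATION_GRADIENTS
    assert zero_config.reduce_bucket_size == ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT
    return zero_config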
'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
'''
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
from torch._six import inf
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment, pg):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
if dist.get_rank(group=pg) == 0:
print("Number of Elements is ", num_elements)
return _flatten_dense_tensors(padded_tensor_list)
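# Hedged illustration, exposition only (not part of this change): the padding
# arithmetic used above, shown without tensors. The helper name is hypothetical.
def _example_aligned_numel(num_elements, alignment):
    remaining = num_elements % alignment
    return num_elements if remaining == 0 else num_elements + (alignment - remaining)
# e.g. 9 elements aligned to a data-parallel degree of 4 are padded to 12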
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
data_parallel_size = int(dist.get_world_size())
if parameter_parallel_size is None:
parameter_parallel_size = int(data_parallel_size)
print(data_parallel_size, parameter_parallel_size)
assert data_parallel_size % parameter_parallel_size == 0, \
'world size should be divisible by parameter parallel size'
rank = dist.get_rank()
my_group = None
for i in range(dist.get_world_size() // parameter_parallel_size):
ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
my_group = group
return my_group
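# Hedged illustration, exposition only (not part of this change): the rank
# grouping produced above, computed without creating process groups. The
# helper name is hypothetical.
def _example_parameter_parallel_ranks(world_size, parameter_parallel_size):
    assert world_size % parameter_parallel_size == 0
    return [
        list(range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size))
        for i in range(world_size // parameter_parallel_size)
    ]
# e.g. a world size of 8 with parameter_parallel_size=4 yields [[0, 1, 2, 3], [4, 5, 6, 7]]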
class FP16_DeepSpeedZeroOptimizer(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
https://arxiv.org/abs/1910.02054
For usage examples, refer to TODO: DeepSpeed V2 Tutorial
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
dp_process_group=None,
partition_size=None,
mpu=None,
all_gather_partitions=True,
allgather_size=500000000,
clip_grad=0.0):
if dp_process_group is not None and partition_size is not None:
raise ValueError("Cannot specify both dp_process_group "
"and partition size")
if dp_process_group is None:
dp_process_group = _initialize_parameter_parallel_groups(partition_size)
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.verbose = verbose
self.dp_process_group = dp_process_group
# TODO: automatically turn off if #params > some_limit
self.all_gather_partitions = all_gather_partitions
self.allgather_size = allgather_size
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
#param partitioned by data parallel degree
#this will contain a list of equal sized tensors
#each of which will be updated by a different process
self.parallel_partitioned_fp16_groups = []
#a single 32-bit partition of the parallel partitioned parameters
#that this process will update
self.single_partition_of_fp32_groups = []
#param partition info
#These are the parameters in each group that will not be updated by this process directly
self.params_not_in_partition = []
#These are the parameters that will be updated by this process directly
self.params_in_partition = []
#Offset from the first parameter in the self.params_in_partition
#the parameter boundaries may not align with partition boundaries
#so we need to keep track of the offset
self.first_offset = []
#number of elements per partition in each group
self.partition_size = []
partition_id = dist.get_rank(group=self.dp_process_group)
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
self.fp16_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group),
self.dp_process_group))
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
#divide the flat weights into near equal partitions equal to the data parallel degree
#each process will compute on a different part of the partition
data_parallel_partitions = self.get_data_parallel_partitions(
self.fp16_groups_flat[i])
self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
# a partition of the fp32 master weights that will be updated by this process
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i]
[partition_id].clone().float().detach())
# modify optimizer to have flat master weight
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
group=self.dp_process_group)
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(self.fp16_groups[i], partition_size, partition_id)
self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)
# we may have a way of fusing dynamic scale. Do not support for now
if dynamic_loss_scale:
self.dynamic_loss_scale = True
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
self.mpu = mpu
self.clip_grad = clip_grad
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
#views the tensor as multiple partitions and returns
#those partitions
def get_data_parallel_partitions(self, tensor):
partitions = []
dp = dist.get_world_size(group=self.dp_process_group)
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
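# Hedged illustration, exposition only (not part of this change): the per-rank
# sizes produced by get_data_parallel_partitions, without tensors. The helper
# name is hypothetical.
@staticmethod
def _example_partition_sizes(total_num_elements, dp):
    base_size = total_num_elements // dp
    remaining = total_num_elements % dp
    return [base_size + 1 if rank < remaining else base_size for rank in range(dp)]
# e.g. 10 elements over a data-parallel degree of 3 -> [4, 3, 3]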
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if (current_index >= start_index and current_index < end_index):
params_in_partition.append(tensor)
elif start_index > current_index and start_index < (current_index +
tensor_size):
params_in_partition.append(tensor)
assert (first_offset==0), "This can happen at most once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
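# Hedged illustration, exposition only (not part of this change): which tensors
# fall into a partition and what first_offset means, using element counts only.
# The helper name is hypothetical.
@staticmethod
def _example_partition_membership(tensor_sizes, partition_size, partition_id):
    start, end = partition_size * partition_id, partition_size * (partition_id + 1)
    in_partition, first_offset, current = [], 0, 0
    for idx, size in enumerate(tensor_sizes):
        if start <= current < end:
            in_partition.append(idx)
        elif current < start < current + size:
            # the partition begins inside this tensor; record where
            in_partition.append(idx)
            first_offset = start - current
        current += size
    return in_partition, first_offset
# e.g. sizes [6, 6] with partition_size=4 and partition_id=1 give ([0, 1], 4):
# partition 1 owns the tail of tensor 0 (from offset 4) and the head of tensor 1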
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
#creates a flat fused tensor from the tensor list starting at the first_offset
#in the first tensor of the list. If there are not enough elements in the tensor
#list then the flat tensor will be padded with zeros
def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype=None):
flat_tensor_list = []
current_size = 0
if not tensor_list:
flat_tensor_list.append(
torch.zeros(int(partition_size),
dtype=dtype,
device=torch.cuda.current_device()))
return _flatten_dense_tensors(flat_tensor_list)
if dtype is None:
dtype = tensor_list[0].dtype
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros(tensor.size(),
dtype=tensor.dtype,
device=tensor.device)
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
#we don't need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
#we need a narrow view of the tensor based on the tensor offset and number of elements that
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)).to(dtype))
else:
flat_tensor_list.append(tensor.to(dtype))
current_size = current_size + num_elements
#this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
dtype=dtype,
device=tensor_list[0].device))
return _flatten_dense_tensors(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None
def see_memory_usage(self):
print("Memory Allocated ",
torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes")
print("Max Memory Allocated ",
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes")
print("Cache Allocated ",
torch.cuda.memory_cached() / (1024 * 1024 * 1024),
"GigaBytes")
print("Max cache Allocated ",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
"GigaBytes")
def print_first_n(self, caption, tensor, n=10):
if tensor is not None:
print(caption,
tensor.data.contiguous().view(-1).narrow(0,
0,
n).cpu().numpy())
else:
print(caption, None)
def step(self, closure=None):
"""
Not supporting closure.
"""
# First compute norm for all group so we know if there is overflow
self.overflow = self.overflow_checker.check()
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self.zero_grad()
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.loss_scale))
return self.overflow
norm_groups = []
single_partition_grad_groups = []
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
norm_groups.append(get_grad_norm(group, mpu=self.mpu))
#free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_in_partition[i])
#create flat gradients for the parameters updated by this process
single_grad_partition = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=self.single_partition_of_fp32_groups[i].dtype)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
#release all the gradients since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(self.params_in_partition[i])
single_partition_grad_groups.append(single_grad_partition)
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
self.optimizer.step()
#get rid of the fp32 gradients. Not needed anymore
for group in self.single_partition_of_fp32_groups:
group.grad = None
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)
dp_world_size = dist.get_world_size(group=self.dp_process_group)
#gather the updated weights from everyone
for _, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
if self.all_gather_partitions:
# controllable memory-time tradeoff
num_shards = max(
1,
partitioned_params[partition_id].numel() * dp_world_size //
self.allgather_size)
shard_size = partitioned_params[partition_id].numel() // num_shards
num_elements = shard_size
for shard_id in range(num_shards + 1):
if shard_id == num_shards:
if shard_size * num_shards >= partitioned_params[
partition_id].numel():
break
else:
num_elements = partitioned_params[partition_id].numel(
) - shard_id * shard_size
shard_list = []
for dp_id in range(dp_world_size):
curr_shard = partitioned_params[dp_id].narrow(
0,
shard_id * shard_size,
num_elements)
shard_list.append(curr_shard)
dist.all_gather(shard_list,
shard_list[partition_id],
group=self.dp_process_group)
else:
#this should require less memory but should be faster
for src, partitioned_param in enumerate(partitioned_params):
global_src = _get_global_rank(self.dp_process_group, src)
dist.broadcast(partitioned_param,
global_src,
group=self.dp_process_group)
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
return self.overflow
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
grad.data.mul_(1. / combined_scale)
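# Hedged illustration, exposition only (not part of this change): the combined
# scale computed above. Gradients are divided by the loss scale to unscale them,
# and additionally by `clip` when the unscaled norm exceeds clip_grad. The
# helper name is hypothetical.
@staticmethod
def _example_combined_scale(total_norm, loss_scale, clip_grad):
    combined_scale = loss_scale
    if clip_grad > 0.:
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        if clip > 1:
            combined_scale = clip * loss_scale
    return combined_scale
# e.g. with loss_scale=1024, a scaled norm of 2048 and clip_grad=1.0, the grads
# are divided by roughly 2 * 1024, clipping the unscaled norm to about 1.0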
def backward(self, loss, retain_graph=False):
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict[
'single_partition_of_fp32_groups'] = self.single_partition_of_fp32_groups
return state_dict
def load_state_dict(self, state_dict, load_optimizer_states=True):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
for current, saved in zip(self.single_partition_of_fp32_groups, state_dict['single_partition_of_fp32_groups']):
current.data.copy_(saved.data)
def __repr__(self):
return repr(self.optimizer)
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
from torch._six import inf
from torch.autograd import Variable
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter
#Toggle this to True to enable a correctness test that compares runs
#with and without gradient partitioning
pg_correctness_test = False
try:
from apex_C import flatten
from apex_C import unflatten
except ImportError:
try:
_ = warned_flatten
except NameError:
print(
"Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
)
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
#no-op override of input() so stray debugging prompts in this module do not block
def input(msg):
return
def split_half_float_double(tensors):
dtypes = [
"torch.cuda.HalfTensor",
"torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor"
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def isclose(a, b, rtol=1e-09, atol=0.0):
return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
from fractions import gcd # or can import gcd from `math` in Python 3
return x * y // gcd(x, y)
#create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment, pg):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
return _flatten_dense_tensors(padded_tensor_list)
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
class FP16_DeepSpeedZeroOptimizer(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
https://arxiv.org/abs/1910.02054
For usage examples, refer to TODO: DeepSpeed Tutorial
"""
def __init__(self,
init_optimizer,
timers,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
contiguous_gradients=True,
reduce_bucket_size=500000000,
allgather_bucket_size=5000000000,
dp_process_group=None,
reduce_scatter=True,
overlap_comm=False,
mpu=None,
clip_grad=0.0):
if dist.get_rank() == 0:
print(f"Reduce bucket size {reduce_bucket_size}")
print(f"Allgather bucket size {allgather_bucket_size}")
# The fused optimizer does all the work. We need this layer for two reasons:
# 1. maintain same user API from apex.fp16_utils
# 2. keep common stuff here in case we need to add a new fused optimizer later
# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params require grad
# - flat by groups, not keeping state. TODO: remove state explicitly?
# - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.timers = timers
self.reduce_scatter = reduce_scatter
self.overlap_comm = overlap_comm
self.dp_process_group = dp_process_group
self.partition_count = dist.get_world_size(group=self.dp_process_group)
if mpu is None:
self.model_parallel_group = None
self.model_parallel_rank = 0
else:
self.model_parallel_group = mpu.get_model_parallel_group()
self.model_parallel_rank = mpu.get_model_parallel_rank()
self.overflow = False
self.clip_grad = clip_grad
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
#param partitioned by data parallel degree
#this will contain a list of equal sized tensors
#each of which will be updated by a different process
self.parallel_partitioned_fp16_groups = []
#a single 32-bit partition of the parallel partitioned parameters
#that this process will update
self.single_partition_of_fp32_groups = []
#param partition info
#These are the parameters in each group that will not be updated by this process directly
self.params_not_in_partition = []
#These are the parameters that will be updated by this process directly
self.params_in_partition = []
#Offset from the first parameter in the self.params_in_partition
#the parameter boundaries may not align with partition boundaries
#so we need to keep track of the offset
self.first_offset = []
#number of elements per partition in each group
self.partition_size = []
partition_id = dist.get_rank(group=self.dp_process_group)
self.all_reduce_print = False
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
#not sure why apex was cloning the weights before flattening
#removing cloning here
see_memory_usage(f"Before moving param group {i} to CPU")
#move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_groups[i])
see_memory_usage(f"After moving param group {i} to CPU")
#create flat buffer in CPU and move to GPU
self.fp16_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group),
self.dp_process_group).cuda(torch.cuda.current_device()))
see_memory_usage(f"After flattening and moving param group {i} to GPU")
if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(
f"After Flattening and after emptying param group {i} cache")
print("")
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
#divide the flat weights into near equal partitions equal to the data parallel degree
#each process will compute on a different part of the partition
data_parallel_partitions = self.get_data_parallel_partitions(
self.fp16_groups_flat[i])
self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
# a partition of the fp32 master weights that will be updated by this process
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i]
[partition_id].clone().float().detach())
# modify optimizer to have flat master weight
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
group=self.dp_process_group)
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(self.fp16_groups[i], partition_size, partition_id)
self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)
self.reduce_bucket_size = int(reduce_bucket_size)
self.allgather_bucket_size = int(allgather_bucket_size)
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream()
self.callback_queued = False
self.param_dict = {}
#map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
self.contiguous_gradients = contiguous_gradients
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
self.params_already_reduced = []
self._release_ipg_buffers()
self.previous_reduced_grads = None
#simplified param id
self.param_id = {}
count = 0
for i, params_group in enumerate(self.fp16_groups):
for param in params_group:
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
count = count + 1
for param_group in self.params_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = True
for param_group in self.params_not_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = False
#mapping from parameter to partition that it belongs to
self.param_to_partition_ids = {}
#stores if a partition has been reduced in this step
self.is_partition_reduced = {}
#number of grads in partition that still need to be computed
self.remaining_grads_in_partition = {}
#total number of grads in partition
self.total_grads_in_partition = {}
#stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
#stores the offset at which a parameter gradient needs to be inserted in a partition
self.grad_partition_insertion_offset = {}
#the offset in the gradient at which it must be inserted at the beginning of the partition
self.grad_start_offset = {}
#will store the averaged gradients required by this partition
self.averaged_gradients = {}
# store index of first parameter in each partition
self.first_param_index_in_partition = {}
#initializes all data structures for implementing gradient partitioning
self.initialize_gradient_partitioning_data_structures()
#resets the data structure value for the next backward propagation
self.reset_partition_gradient_structures()
#creates backward hooks for gradient partitioning
self.create_reduce_and_remove_grad_hooks()
# we may have a way of fusing dynamic scale. Do not support for now
if dynamic_loss_scale:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
see_memory_usage("Before initializing optimizer states")
self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states")
if dist.get_rank() == 0:
print(f"optimizer state initialized")
if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(f"After initializing ZeRO optimizer")
print("")
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
self.grads_in_partition = None
self.grads_in_partition_offset = 0
def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
single_grad_partition = torch.zeros(
int(self.partition_size[i]),
dtype=self.single_partition_of_fp32_groups[i].dtype).cuda()
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
self.optimizer.step()
for group in self.single_partition_of_fp32_groups:
group.grad = None
return
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, param_group in enumerate(self.fp16_groups):
self.param_to_partition_ids[i] = {}
self.is_partition_reduced[i] = {}
self.total_grads_in_partition[i] = {}
self.remaining_grads_in_partition[i] = {}
self.is_grad_computed[i] = {}
self.grad_partition_insertion_offset[i] = {}
self.grad_start_offset[i] = {}
self.first_param_index_in_partition[i] = {}
for partition_id in range(total_partitions):
self.is_grad_computed[i][partition_id] = {}
self.grad_partition_insertion_offset[i][partition_id] = {}
self.grad_start_offset[i][partition_id] = {}
self.initialize_gradient_partition(i, param_group, partition_id)
self.is_partition_reduced[i][partition_id] = False
self.first_param_index_in_partition[i][
partition_id] = self.get_first_param_index(
i,
param_group,
partition_id)
def independent_gradient_partition_epilogue(self):
self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
self.reduce_ipg_grads()
self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
#if dist.get_rank() == 0:
# print("Params already reduced ", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
if self.overlap_comm:
torch.cuda.synchronize()
for i, _ in enumerate(self.fp16_groups):
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
return_tensor_list=True)
self._release_ipg_buffers()
see_memory_usage(f"End ipg_epilogue")
# resets all partitions to not reduced
# sets remaining grads to the total number of grads in each partition
# sets is_grad_computed to False for all grads in each partition
def reset_partition_gradient_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.fp16_groups):
for partition_id in range(total_partitions):
self.is_partition_reduced[i][partition_id] = False
self.remaining_grads_in_partition[i][
partition_id] = self.total_grads_in_partition[i][partition_id]
for param_id in self.is_grad_computed[i][partition_id]:
self.is_grad_computed[i][partition_id][param_id] = False
def initialize_gradient_partition(self, i, param_group, partition_id):
def set_key_value_list(dictionary, key, value):
if key in dictionary:
dictionary[key].append(value)
else:
dictionary[key] = [value]
def increment_value(dictionary, key):
if key in dictionary:
dictionary[key] += 1
else:
dictionary[key] = 1
partition_size = self.partition_size[i]
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for param in param_group:
param_size = param.numel()
param_id = self.get_param_id(param)
if (current_index >= start_index and current_index < end_index):
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][
param_id] = current_index - start_index
self.grad_start_offset[i][partition_id][param_id] = 0
elif start_index > current_index and start_index < (current_index +
param_size):
assert (first_offset==0), "This can happen at most once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
self.grad_start_offset[i][partition_id][param_id] = first_offset
current_index = current_index + param_size
def overlapping_partition_gradients_reduce_epilogue(self):
self.independent_gradient_partition_epilogue()
def create_reduce_and_remove_grad_hooks(self):
self.grad_accs = []
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
if param.requires_grad:
def wrapper(param, i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param, i)
grad_acc.register_hook(reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)
wrapper(param, i)
def get_param_id(self, param):
unique_id = id(param)
return self.param_id[unique_id]
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
see_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
)
###############Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
param.numel())
self.reduce_ipg_grads()
if self.contiguous_gradients and self.overlap_comm:
# Swap ipg_index between 0 and 1
self.ipg_index = 1 - self.ipg_index
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
param.numel())
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
#keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
if self.contiguous_gradients:
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
0,
self.elements_in_ipg_bucket,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.elements_in_ipg_bucket += param.numel()
self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
def print_rank_0(self, message):
if dist.get_rank() == 0:
print(message)
def average_tensor(self, tensor):
if self.overlap_comm:
torch.cuda.synchronize()
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
if not self.reduce_scatter:
tensor.div_(dist.get_world_size(group=self.dp_process_group))
dist.all_reduce(tensor, group=self.dp_process_group)
return
# Accumulate destination ranks and bucket offsets for each gradient slice.
# Note: potential future optimization, record access pattern of parameters
# in backward pass and partition gradients w.r.t. access pattern so that our
# bucket is guaranteed to be contiguous w.r.t. ranks
rank_and_offsets = []
curr_size = 0
prev_id = -1
for i, param, param_id in self.params_in_ipg_bucket:
partition_ids = self.param_to_partition_ids[i][param_id]
partition_size = self.partition_size[i]
# Get all partition ids + their offsets
partition_ids_w_offsets = []
for partition_id in partition_ids:
offset = self.grad_start_offset[i][partition_id][param_id]
partition_ids_w_offsets.append((partition_id, offset))
partition_ids_w_offsets.sort(key=lambda t: t[1])
# Calculate rank and offsets for grad slices
for idx in range(len(partition_ids_w_offsets)):
partition_id, offset = partition_ids_w_offsets[idx]
# Calculate numel for grad slice depending on partition location
if idx == len(partition_ids_w_offsets) - 1:
# Last partition_id uses its own offset
numel = param.numel() - offset
else:
# Set numel to next partition's offset
numel = partition_ids_w_offsets[idx + 1][1] - offset
# Merge bucket ranges if they belong to the same rank
if partition_id == prev_id:
prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel)
else:
rank_and_offsets.append((partition_id, curr_size, numel))
curr_size += numel
prev_id = partition_id
tensor.div_(dist.get_world_size(group=self.dp_process_group))
async_handles = []
for dst, bucket_offset, numel in rank_and_offsets:
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
dst_rank = _get_global_rank(self.dp_process_group, dst)
async_handle = dist.reduce(grad_slice,
dst=dst_rank,
group=self.dp_process_group,
async_op=True)
async_handles.append(async_handle)
for handle in async_handles:
handle.wait()
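# Hedged illustration, exposition only (not part of this change): how adjacent
# grad slices destined for the same rank are merged into a single
# (dst, bucket_offset, numel) reduce call. The helper name is hypothetical and
# its input is a list of (dst_rank, numel) slices in bucket order.
@staticmethod
def _example_merge_slices(slices):
    rank_and_offsets, curr_size, prev_id = [], 0, -1
    for dst, numel in slices:
        if dst == prev_id:
            pid, offset, merged = rank_and_offsets[-1]
            rank_and_offsets[-1] = (pid, offset, merged + numel)
        else:
            rank_and_offsets.append((dst, curr_size, numel))
        curr_size += numel
        prev_id = dst
    return rank_and_offsets
# e.g. [(0, 5), (0, 3), (1, 4)] -> [(0, 0, 8), (1, 8, 4)]: one reduce of 8
# elements to rank 0 and one reduce of 4 elements (at offset 8) to rank 1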
def copy_grads_in_partition(self, param):
if self.grads_in_partition is None:
self.grads_in_partition_offset = 0
total_size = 0
for group in self.params_in_partition:
for param_in_partition in group:
total_size += param_in_partition.numel()
see_memory_usage(f"before copying {total_size} gradients into partition")
self.grads_in_partition = torch.empty(int(total_size),
dtype=torch.half).cuda()
see_memory_usage(f"after copying {total_size} gradients into partition")
#The allreduce buffer will be rewritten. Copy the gradients in the partition to a new buffer
new_grad_tensor = self.grads_in_partition.narrow(0,
self.grads_in_partition_offset,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.grads_in_partition_offset += param.numel()
def reduce_ipg_grads(self):
if self.overlap_comm:
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
if self.contiguous_gradients:
self.average_tensor(self.ipg_buffer[self.ipg_index])
else:
self.buffered_reduce_fallback(
None,
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
with torch.cuda.stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
self.params_already_reduced[param_id] = True
if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear the previous grads during the next reduction
# to avoid clearing them before the reduction is complete.
if self.previous_reduced_grads is None:
self.previous_reduced_grads = []
self.previous_reduced_grads.append(param)
else:
param.grad = None
elif self.contiguous_gradients:
self.copy_grads_in_partition(param)
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
#####################################################################
def reduce_ready_partitions_and_remove_grads(self, param, i):
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
def zero_reduced_gradients(self, partition_id, i):
def are_all_related_partitions_reduced(params_id):
for partition_id in self.param_to_partition_ids[i][params_id]:
if not self.is_partition_reduced[i][partition_id]:
return False
return True
for params_id in self.is_grad_computed[i][partition_id]:
if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None
def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = _flatten_dense_tensors(tensors)
def print_func():
print(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
self.sequential_execution(print_func, message)
def get_grads_to_reduce(self, i, partition_id):
def get_reducable_portion(key):
grad = self.param_dict[key].grad
total_elements = grad.numel()
start = self.grad_start_offset[i][partition_id][key]
num_elements = min(
total_elements - start,
self.partition_size[i] -
self.grad_partition_insertion_offset[i][partition_id][key])
if not pg_correctness_test:
if num_elements == total_elements:
return grad
else:
return grad.contiguous().view(-1).narrow(0,
int(start),
int(num_elements))
else:
if num_elements == total_elements:
return grad.clone()
else:
return grad.clone().contiguous().view(-1).narrow(
0,
int(start),
int(num_elements))
grads_to_reduce = []
for key in self.is_grad_computed[i][partition_id]:
grad = get_reducable_portion(key)
grads_to_reduce.append(grad)
return grads_to_reduce
def sequential_execution(self, function, message, group=None):
if group is None:
group = self.dp_process_group
if dist.get_rank(group=group) == 0:
print(message)
for id in range(dist.get_world_size(group=group)):
if id == dist.get_rank(group=group):
function()
dist.barrier(group=group)
def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
param.grad = torch.zeros_like(param)
######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
rank = None
tensor = flatten(bucket)
tensor_to_allreduce = tensor
if pg_correctness_test:
allreduce_always_fp32 = True
if allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = _get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
tensor.copy_(tensor_to_allreduce)
return tensor
#if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
if self.overlap_comm:
torch.cuda.synchronize()
if self.previous_reduced_grads is not None:
# previous_reduced_grads has the previous reduced grads,
# now it is safe to clear.
for param in self.previous_reduced_grads:
param.grad = None
self.previous_reduced_grads = None
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self,
bucket,
numel_per_bucket=500000000,
rank=None,
log=None):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
#allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self,
rank,
grads,
elements_per_buffer=500000000,
log=None):
split_buckets = split_half_float_double(grads)
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket,
numel_per_bucket=elements_per_buffer,
rank=rank,
log=log)
#############################################################################
#############################################################################
#############################################################################
#views the tensor as multiple partitions and returns
#those partitions
def get_data_parallel_partitions(self, tensor):
partitions = []
dp = dist.get_world_size(group=self.dp_process_group)
dp_id = dist.get_rank(group=self.dp_process_group)
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if (current_index >= start_index and current_index < end_index):
params_in_partition.append(tensor)
elif start_index > current_index and start_index < (current_index +
tensor_size):
params_in_partition.append(tensor)
assert (first_offset==0), "This can happen at most once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _model_parallel_all_reduce(self, tensor, op):
""" Perform all reduce within model parallel group, if any.
"""
if self.model_parallel_group is None:
torch.distributed.all_reduce(tensor=tensor, op=op)
else:
torch.distributed.all_reduce(tensor=tensor,
op=op,
group=self.model_parallel_group)
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
#if dist.get_rank() == 0:
# print(f"Total Norm begining {total_norm}")
for g, p in zip(gradients, params):
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
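# Hedged illustration, exposition only (not part of this change): the L2
# reduction performed above, where each rank contributes the squared norms of
# the grads it owns. The helper name is hypothetical.
@staticmethod
def _example_global_l2_norm(per_rank_squared_sums):
    return sum(per_rank_squared_sums)**0.5
# e.g. squared sums of 9.0 and 16.0 from two ranks give a global norm of 5.0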
#creates a flat fused tensor from the tensor list starting at the first_offset
#in the first tensor of the list. If there are not enough elements in the tensor
#list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
partition_size,
return_tensor_list=False):
flat_tensor_list = []
current_size = 0
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
continue
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
#we don't need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
#we need a narrow view of the tensor based on the tensor offset and number of elements that
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)))
else:
flat_tensor_list.append(tensor)
current_size = current_size + num_elements
#this means it's the last partition and does not align with the dp boundary; we need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
dtype=tensor_list[0].dtype,
device=tensor_list[0].device))
if return_tensor_list:
return flat_tensor_list
return _flatten_dense_tensors(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None
def step(self, closure=None):
"""
Closures are not supported.
"""
see_memory_usage(f"In step before checking overflow")
# First compute the norm for all groups so we know if there is an overflow
self.check_overflow()
timers = self.timers
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
see_memory_usage('After overflow before clearing gradients')
self.zero_grad()
see_memory_usage('After overflow after clearing gradients')
print(
"[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
timers('optimizer_step').start()
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
timers('optimizer_allgather').stop()
return
norm_groups = []
single_partition_grad_groups = []
skip = False
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
#free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_in_partition[i])
#create flat gradients for parameters updated by this process
# If we are the last partition, ensure the flattened grads match the partition size; if not, pad with zero tensors
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i]),
self.dp_process_group).to(
self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = _flatten_dense_tensors(
self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have a different number of elements than the partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
#release all the gradients since we have already created the necessary copy in dp_grad_partition
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
single_partition_grad_groups.append(single_grad_partition)
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
timers('optimizer_step').start()
self.optimizer.step()
#get rid of the fp32 gradients. Not needed anymore
for group in self.single_partition_of_fp32_groups:
group.grad = None
for i in range(len(norm_groups)):
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
#gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
#Sequential AllGather: gather the updated partitions in shards bounded by allgather_bucket_size
dp_world_size = dist.get_world_size(group=self.dp_process_group)
num_shards = max(
1,
partitioned_params[partition_id].numel() * dp_world_size //
self.allgather_bucket_size)
if num_shards == 1:
dist.all_gather(partitioned_params,
partitioned_params[partition_id],
group=self.dp_process_group)
else:
shard_size = partitioned_params[partition_id].numel() // num_shards
num_elements = shard_size
for shard_id in range(num_shards):
#boundary condition: the last shard also covers any remainder, so every
#element of the partition is gathered exactly once
if shard_id == (num_shards - 1):
num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size
shard_list = []
for dp_id in range(dp_world_size):
curr_shard = partitioned_params[dp_id].narrow(
0,
shard_id * shard_size,
num_elements)
shard_list.append(curr_shard)
dist.all_gather(shard_list,
shard_list[partition_id],
group=self.dp_process_group)
timers('optimizer_allgather').stop()
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
see_memory_usage('After zero_optimizer step')
return
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.fp16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
def has_overflow(self, partition_gradients=True):
if partition_gradients:
overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
else:
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
overflow_gpu = torch.cuda.ByteTensor([overflow])
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu,
op=torch.distributed.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# PyTorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent versions of PyTorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if instance is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
if dist.get_rank() == 0 and j is not None:
_handle_overflow(cpu_sum, x, j)
return True
return False
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(self.reduce_bucket_size, dtype=torch.half).cuda()
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(self.reduce_bucket_size, dtype=torch.half).cuda()
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict[
'single_partition_of_fp32_groups'] = self.single_partition_of_fp32_groups
state_dict['partition_count'] = self.partition_count
return state_dict
def load_state_dict(self, state_dict, load_optimizer_states=True):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 1 if changing DP degree and option 2 otherwise.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
if 'partition_count' in state_dict and state_dict[
'partition_count'] == self.partition_count:
# Use option 2
for current, saved in zip(self.single_partition_of_fp32_groups, state_dict['single_partition_of_fp32_groups']):
current.data.copy_(saved.data)
else:
# Use option 1
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp32_partition.data.copy_(fp16_partitions[partition_id].data)
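# --- Editor's sketch: not part of the DeepSpeed sources in this diff --------
# A torch-free illustration of the sharded all-gather performed in step()
# above: each rank's fp16 partition is gathered in num_shards slices so that a
# single all_gather moves at most roughly allgather_bucket_size elements
# across the data-parallel group, with the last shard picking up any
# remainder. The helper and names below are local to this sketch and only
# reproduce the (offset, length) arithmetic; no tensors or collectives are
# involved.
def _sketch_allgather_shard_bounds(partition_numel, dp_world_size, allgather_bucket_size):
    num_shards = max(1, partition_numel * dp_world_size // allgather_bucket_size)
    shard_size = partition_numel // num_shards
    bounds = []
    for shard_id in range(num_shards):
        if shard_id == num_shards - 1:
            # the last shard covers everything that is left
            num_elements = partition_numel - shard_id * shard_size
        else:
            num_elements = shard_size
        bounds.append((shard_id * shard_size, num_elements))
    return bounds
# Example: a 10-element partition on 4 ranks with a bucket of 16 elements is
# gathered in two shards of 5 elements each.
assert _sketch_allgather_shard_bounds(10, 4, 16) == [(0, 5), (5, 5)]
# -----------------------------------------------------------------------------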
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
print(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
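# --- Editor's sketch: not part of the DeepSpeed sources in this diff --------
# A minimal, self-contained illustration of the unscale-and-clip logic used by
# unscale_and_clip_grads() in the optimizer above: per-group gradient norms are
# combined into one global norm, and gradients are later multiplied by
# 1 / combined_scale, which folds the loss scale and (optionally) the clipping
# ratio into a single factor. The function and names below are local to this
# sketch.
import math

def _sketch_combined_unscale_clip_factor(norm_groups, loss_scale, clip_grad):
    # combine per-group L2 norms: square root of the sum of squares
    total_norm = math.sqrt(sum(n ** 2.0 for n in norm_groups))
    combined_scale = loss_scale
    if clip_grad > 0.0:
        # the norms were computed on scaled gradients, so unscale before clipping
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        if clip > 1:
            combined_scale = clip * loss_scale
    return combined_scale

# Example: two groups with scaled-gradient norms 3.0 and 4.0, a loss scale of
# 1024 and clip_grad=1.0 give a global scaled norm of 5.0; the unscaled norm is
# well below 1.0, so no extra clipping is applied and the factor is just the
# loss scale.
assert _sketch_combined_unscale_clip_factor([3.0, 4.0], 1024.0, 1.0) == 1024.0
# -----------------------------------------------------------------------------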
import math
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from collections import defaultdict
from deepspeed.pt.zero_utils import _initialize_parameter_parallel_groups, \
pprint
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
def flatten_dense_tensors_sub_partition_aligned(tensor_list,
dp,
max_elements_per_comm,
pg):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
pprint("Total number of elements in model: {}, max elements per com: {}".format(
num_elements,
max_elements_per_comm))
max_elements_per_comm = min(max_elements_per_comm, num_elements)
sub_partition_size = int(max_elements_per_comm // dp)
alignment = sub_partition_size
# if alignment == 0:
# # number of elements not divisible by dp, outside range and small model must pad with zeroes
# pad_tensor = torch.zeros(max_elements_per_comm,
# device=tensor_list[0].device,
# dtype=tensor_list[0].dtype)
# return _flatten_dense_tensors(pad_tensor)
remaining = int(num_elements % alignment)
# ensure we have equal-sized sub-partitions
elements_to_add = 0
if remaining:
elements_to_add = alignment - remaining
# add the pad tensor later, after we check comm alignment
pprint("adding pad tensor for alignment, {} + {}->{}".format(
num_elements,
elements_to_add,
num_elements + elements_to_add))
#num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
num_partitions = int((num_elements + elements_to_add) // sub_partition_size)
assert (num_elements + elements_to_add) % sub_partition_size == 0, "num elements should be " \
"aligned by sub partition " \
"size"
num_comm_intervals = int(num_partitions // dp)
partition_remaining = int(num_partitions % dp)
pprint("num_comm_intervals={}, partition_remaining={}".format(
num_comm_intervals,
partition_remaining))
if partition_remaining != 0:
pprint("adding pad tensor and/or extra sub partition")
# add pad tensor for alignment of the comm interval; this supersedes the possible earlier sub-partition-only alignment
num_comm_intervals += 1
aligned_comm_elements = num_comm_intervals * sub_partition_size * dp
elements_to_add = aligned_comm_elements - num_elements
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
pprint("adding pad tensor and/or extra sub partition, {} + {}->{}".format(
num_elements,
elements_to_add,
num_elements + elements_to_add))
num_elements += elements_to_add
elif elements_to_add > 0:
# add pad tensor for just alignment of sub-partition
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements += elements_to_add
if pg is None or dist.get_rank(group=pg) == 0:
print("Number of Elements (w. padding) is ", num_elements)
padded_num_elems = 0
for p in padded_tensor_list:
padded_num_elems += p.numel()
assert num_elements == padded_num_elems, "{} != {}, rank={}".format(num_elements, padded_num_elems, dist.get_rank())
return _flatten_dense_tensors(padded_tensor_list)
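# --- Editor's sketch: not part of the DeepSpeed sources in this diff --------
# A hedged illustration of the padding arithmetic performed by
# flatten_dense_tensors_sub_partition_aligned() above: the flat buffer is
# padded so its length becomes a whole number of communication intervals,
# i.e. a multiple of (sub_partition_size * dp). The helper below only
# reproduces the size computation and builds no tensors; its names are local
# to this sketch.
def _sketch_aligned_flat_size(num_elements, dp, max_elements_per_comm):
    max_elements_per_comm = min(max_elements_per_comm, num_elements)
    sub_partition_size = max_elements_per_comm // dp
    comm_chunk = sub_partition_size * dp
    # round up to the nearest whole communication interval (ceiling division)
    num_comm_intervals = -(-num_elements // comm_chunk)
    return num_comm_intervals * comm_chunk, sub_partition_size, num_comm_intervals

# Example: 10 elements with dp=4 and max_elements_per_comm=8 give a
# sub-partition size of 2; one comm interval holds 8 elements, so the buffer
# is padded to 16 elements spread over 2 comm intervals.
assert _sketch_aligned_flat_size(10, 4, 8) == (16, 2, 2)
# -----------------------------------------------------------------------------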
def _single_range_check(current_index, start_index, end_index, tensor_size):
offset = 0
if (current_index >= start_index) and (current_index < end_index):
# The tensor starts inside this range
return True, offset
elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
# The range starts partway into this tensor; compute the offset into the tensor
offset = start_index - current_index
return True, offset
else:
return False, offset
def _range_check(current_index, element_intervals, tensor_size):
results = []
for comm_idx, interval in enumerate(element_intervals):
start_index, end_index = interval
contained, offset = _single_range_check(current_index, start_index, end_index, tensor_size)
if contained:
results.append((contained, offset, comm_idx))
if len(results) == 0:
return [(False, 0, -1)]
return results
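# --- Editor's sketch: not part of the DeepSpeed sources in this diff --------
# A worked example of _range_check() above. element_intervals lists the
# [start, end) ranges of the flat buffer owned by one rank, one entry per comm
# interval. For a tensor starting at flat index current_index and spanning
# tensor_size elements, _range_check reports every owned interval the tensor
# touches, together with the offset into the tensor where that interval begins
# and the comm interval id. The wrapper below is local to this sketch and is
# never called by the optimizer.
def _sketch_range_check_example():
    element_intervals = [(0, 4), (8, 12)]  # this rank owns [0, 4) and [8, 12)
    # the tensor covers flat indices [2, 10): it overlaps interval 0 from its
    # start (offset 0) and interval 1 beginning 6 elements in (offset 8 - 2)
    hits = _range_check(current_index=2,
                        element_intervals=element_intervals,
                        tensor_size=8)
    assert hits == [(True, 0, 0), (True, 6, 1)]
    # a tensor entirely outside the owned ranges yields a single miss entry
    assert _range_check(20, element_intervals, 4) == [(False, 0, -1)]
# -----------------------------------------------------------------------------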
class FP16_DeepSpeedZeroOptimizer_Stage1(object):
"""
FP16_DeepSpeedZeroOptimizer_Stage1 designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
This version aligns with stage-1 in the paper above.
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
dp_process_group=None,
partition_size=None,
mpu=None,
all_gather_partitions=True,
allgather_size=500000000,
clip_grad=0.0,
max_elements_per_comm=5e8):
if dp_process_group is not None and partition_size is not None:
raise ValueError("Cannot specify both dp_process_group "
"and partition size")
if dp_process_group is None:
dp_process_group = _initialize_parameter_parallel_groups(partition_size)
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.verbose = verbose
self.dp_process_group = dp_process_group
# TODO: automatically turn off if #params > some_limit
self.all_gather_partitions = all_gather_partitions
self.allgather_size = allgather_size
self.max_elements_per_comm = max_elements_per_comm
print("max_elements_per_comm={}".format(max_elements_per_comm))
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
# Setup bookkeeping data structures depending on partitioning type
# parallel_sub_partitioned_fp16_groups[group-idx] -> [comm-ids] -> [rank-ids]
self.parallel_sub_partitioned_fp16_groups = []
# same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
self.parallel_comm_sub_partitioned_fp16_groups = []
# 32-bit sub-partitions of the parallel partitioned parameters
# that this process will update
self.local_sub_partitions_of_fp32_groups = []
# param partition info
# parameters in each group that will not be updated by this process directly
self.params_not_local = []
# parameters that will be updated by this process directly
self.params_in_rank_sub_partitions = []
# parameter offsets for parameters in sub-partitions. Parameter
# boundaries may not align with sub-partition boundaries
# so we need to keep track of the offsets
self.params_in_rank_sub_partitions_offsets = []
# number of elements per sub-partition in each group
self.sub_partition_sizes = []
# number of communication intervals for each group
self.num_comm_intervals_per_group = []
local_rank = dist.get_rank(group=self.dp_process_group)
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to the list before modifying it
self.fp16_groups.append(param_group['params'])
# flatten all tensors into a single 1d tensor aligned with the sub-partition size for later dividing
# RS: create aligned sub-partitions
self.fp16_groups_flat.append(
flatten_dense_tensors_sub_partition_aligned(
tensor_list=self.fp16_groups[i],
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=self.max_elements_per_comm,
pg=self.dp_process_group))
# TODO: I don't think this does anything?
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
# divide the flat weights into near-equal partitions equal to the data parallel degree
# each process will compute on a different part of the partition
# RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=self.fp16_groups_flat[i],
max_elements_per_comm=self.max_elements_per_comm,
world_size=dist.get_world_size(
group=self.dp_process_group),
dp_process_group=self.dp_process_group
)
self.parallel_comm_sub_partitioned_fp16_groups.append(
comm_partitions) # comm -> rank
self.parallel_sub_partitioned_fp16_groups.append(
dp_sub_partitions) # rank -> comm
self.sub_partition_sizes.append(sub_partition_size)
self.num_comm_intervals_per_group.append(num_comm_intervals)
# data_parallel_partitions = self.get_data_parallel_partitions(self.fp16_groups_flat[i])
# self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
# a partition of the fp32 master weights that will be updated by this process
# RS: store/detach/cast our local sub-partitions
local_sub_partitions = []
for sub_partition in self.parallel_sub_partitioned_fp16_groups[i][
local_rank]:
fp32_sub_partition = sub_partition.clone().float().detach()
fp32_sub_partition.requires_grad = True
local_sub_partitions.append(fp32_sub_partition)
self.local_sub_partitions_of_fp32_groups.append(local_sub_partitions)
# modify optimizer to have flat master weights
# self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = self.local_sub_partitions_of_fp32_groups[i]
# RS: divide up the sub-partitions and keep track of offsets for each param
# partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
params_not_local = self.get_all_sub_partition_info(
tensor_list=self.fp16_groups[i],
all_element_intervals=element_intervals,
local_rank=local_rank,
world_size=dist.get_world_size(group=self.dp_process_group)
)
self.params_in_rank_sub_partitions.append(params_in_rank_sub_partition)
self.params_not_local.append(params_not_local)
self.params_in_rank_sub_partitions_offsets.append(
params_in_rank_sub_partitions_offsets)
# we may have a way of fusing dynamic scale. Not supported for now
if dynamic_loss_scale:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
self.mpu = mpu
self.clip_grad = clip_grad
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups,
mpu=self.mpu,
zero_reduce_scatter=True)
@staticmethod
def get_data_parallel_sub_partitions(tensor,
max_elements_per_comm,
world_size,
dp_process_group=None):
total_num_elements = tensor.numel()
# if the total number of elements is less than our max, revert to splitting into dp partitions
max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
sub_partition_size = int(max_elements_per_comm // world_size)
# Ensure partition alignment was done correctly
num_sub_partitions = int(total_num_elements // sub_partition_size)
assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements, sub_partition_size)
# Ensure comm interval alignment was done correctly.
num_comm_intervals = int(num_sub_partitions // world_size)
assert num_sub_partitions % world_size == 0, "{} % {} != 0".format(num_sub_partitions, world_size)
if not dist.is_initialized() or dist.get_rank(group=dp_process_group) == 0:
print("**** partition info:")
print("\t total_num_elements=", total_num_elements)
print("\t world_size=", world_size)
print("\t max_elements_per_comm=", max_elements_per_comm)
print("\t sub_partition_size=", sub_partition_size)
print("\t num_sub_partitions=", num_sub_partitions)
print("\t num_comm_intervals=", num_comm_intervals)
print("****")
# [comm_id] -> [rank]
comm_partitions = []
for _ in range(num_comm_intervals):
comm_partitions.append([])
start = 0
comm_id = 0
element_intervals = defaultdict(
list) # [rank] -> [(start,end), (start,end), ...]
for idx in range(num_sub_partitions):
rank_id = idx % world_size
sub_partition = tensor.narrow(0, start, sub_partition_size)
element_intervals[rank_id].append((start, start + sub_partition_size))
comm_partitions[comm_id].append(sub_partition)
start = start + sub_partition_size
if rank_id == (world_size - 1):
comm_id += 1
# [rank] -> [comm_id]
sub_partitions = []
for _ in range(world_size):
sub_partitions.append([])
for comm_id, partitions in enumerate(comm_partitions):
for rank_id, partition in enumerate(partitions):
sub_partitions[rank_id].append(partition)
return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals
@staticmethod
def get_all_sub_partition_info(tensor_list,
all_element_intervals,
local_rank,
world_size):
params_not_local = []
# [rank] -> [comm-id] -> [param/offset]
params_in_rank_sub_partition = []
params_in_rank_sub_partitions_offsets = []
for rank in range(world_size):
params_in_local_sub_partition = []
local_sub_partition_offsets = []
comm_tensor_list = []
comm_offset_list = []
current_index = 0
prev_comm_idx = 0
for iii, tensor in enumerate(tensor_list):
tensor_size = tensor.numel()
#if local_rank == 0:
# #print("rank={}, current_index={}, tensor_size={}, tensor-idx={}".format(rank,
# current_index, tensor_size, iii))
results_list = _range_check(current_index,
all_element_intervals[rank],
tensor_size)
for contained, offset, comm_idx in results_list:
#if local_rank == 0:
# print("rank={}, contained={}, offset={}, comm_idx={}".format(rank, contained,
# offset, comm_idx))
if contained:
if prev_comm_idx != comm_idx:
params_in_local_sub_partition.append(comm_tensor_list)
comm_tensor_list = []
local_sub_partition_offsets.append(comm_offset_list)
comm_offset_list = []
comm_tensor_list.append(tensor)
comm_offset_list.append(offset)
prev_comm_idx = comm_idx
elif rank == local_rank:
params_not_local.append(tensor)
current_index = current_index + tensor_size
#assert len(comm_tensor_list) > 0
#assert len(comm_offset_list) > 0
params_in_local_sub_partition.append(comm_tensor_list)
local_sub_partition_offsets.append(comm_offset_list)
params_in_rank_sub_partition.append(params_in_local_sub_partition)
params_in_rank_sub_partitions_offsets.append(local_sub_partition_offsets)
return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
@staticmethod
def get_flat_sub_partitions(comm_tensor_list,
comm_param_offsets,
sub_partition_size,
dtype,
num_comm_intervals=None,
default_device=None,
return_partition_params=False):
partition_params = []
final_param_offsets = []
flat_sub_partitions = []
for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
flat_tensor_list = []
current_size = 0
my_offsets = []
my_params = []
if dtype is None:
dtype = tensor_list[0].dtype
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros(tensor.size(),
dtype=tensor.dtype,
device=tensor.device)
param = tensor
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and param_offsets[i] > 0:
tensor_offset = param_offsets[i]
num_elements = num_elements - tensor_offset
# We don't need all elements of the tensor if this tensor is
# larger than we have space for in our curr sub-partition
if num_elements > (sub_partition_size - current_size):
num_elements = sub_partition_size - current_size
#we need a narrow view of the tensor based on the tensor offset and number of elements that
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)).to(dtype))
else:
flat_tensor_list.append(tensor.to(dtype))
my_params.append(param)
#remember offset into partition and #elems for this tensor
my_offsets.append((current_size, num_elements))
current_size = current_size + num_elements
#this means it's the last partition and does not align with the dp boundary; we need to pad before flattening
if current_size < sub_partition_size:
my_offsets.append((None, None))
my_params.append(None)
if len(tensor_list) == 0:
assert default_device is not None
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=default_device))
else:
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=tensor_list[0].device))
partition_params.append(my_params)
final_param_offsets.append(my_offsets)
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))
if num_comm_intervals is not None and len(
flat_sub_partitions) < num_comm_intervals:
#print("padding w. sub partitions to ensure uniform communication")
device = flat_sub_partitions[0].device
for _ in range(num_comm_intervals - len(flat_sub_partitions)):
flat_sub_partitions.append(
torch.zeros(int(sub_partition_size),
dtype=dtype,
device=device))
partition_params.append([None])
final_param_offsets.append([(None, None)])
if return_partition_params:
assert len(flat_sub_partitions) == len(partition_params)
assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params), len(final_param_offsets))
return flat_sub_partitions, partition_params, final_param_offsets
return flat_sub_partitions
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def free_grad_in_param_list(self, param_list):
for p in param_list:
if isinstance(p, list):
for _p in p:
_p.grad = None
else:
p.grad = None
def reduce_scatter_gradients(self,
postscale_gradients,
gradient_predivide_factor,
gradient_average):
world_size = dist.get_world_size(group=self.dp_process_group)
local_rank = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
partition_param_map = {}
param_partition_map = {}
my_params = set()
# [rank] -> [comm] -> partition
num_comm_intervals = self.num_comm_intervals_per_group[i]
all_sub_partitions = []
for rank in range(world_size):
# gsp is list of partitions indexed by comm_idx
#FIXME: currently hardcoding fp16, should infer dtype
grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
sub_partition_size=self.sub_partition_sizes[i],
dtype=torch.half, #self.params_in_rank_sub_partitions[i][rank][0][0].dtype,
num_comm_intervals=self.num_comm_intervals_per_group[i],
default_device='cuda', #self.params_in_rank_sub_partitions[i][rank][0][0].device,
return_partition_params=True)
all_sub_partitions.append(grad_sub_partitions)
# create map from partition -> params in that partition
for comm_idx, part in enumerate(grad_sub_partitions):
partition_param_map[part] = (partition_params[comm_idx],
param_offsets[comm_idx])
for comm_idx, params in enumerate(partition_params):
for pidx, p in enumerate(params):
# store the parameters we care about locally
if rank == local_rank:
my_params.add(p)
# map from param -> partitions
if p in param_partition_map:
param_partition_map[p].append(grad_sub_partitions[comm_idx])
else:
param_partition_map[p] = [grad_sub_partitions[comm_idx]]
assert len(grad_sub_partitions) == num_comm_intervals
if not postscale_gradients:
raise NotImplementedError("pre-scale_gradients is not implemented")
all_comm_partitions = []
for comm_idx in range(num_comm_intervals):
single_comm_all_partitions = []
for rank in range(world_size):
single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
input_list=single_comm_all_partitions,
group=self.dp_process_group)
if gradient_average:
for partition in single_comm_all_partitions:
partition.mul_(gradient_predivide_factor / world_size)
all_comm_partitions.append(single_comm_all_partitions)
for p in my_params:
partitions = param_partition_map[p]
parts = []
for part in partitions:
params, offsets = partition_param_map[part]
found = False
for p_idx, _p in enumerate(params):
if p is _p:
found = True
if offsets[p_idx][0] is not None:
my_part = part.narrow(0,
offsets[p_idx][0],
offsets[p_idx][1])
parts.append(my_part)
assert found
if p is not None:
updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p])
p.grad.copy_(updated_grad[0])
def step(self, closure=None):
# First compute the norm for all groups so we know if there is an overflow
self.overflow = self.overflow_checker.check()
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self.zero_grad()
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.loss_scale))
return self.overflow
norm_groups = []
local_sub_partitions_grad_groups = []
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
#TODO RS: update get grad norm to support sub partitions
norm_groups.append(get_grad_norm(group, mpu=self.mpu))
#RS: update free grads w.r.t. sub partitions
#free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_local[i])
#create flat gradients for parameters updated by this process
#tensor_list, first_offset, partition_size, dtype
#single_grad_partition = self.get_flat_partition(
# tensor_list=self.params_in_partition[i],
# first_offset=self.first_offset[i],
# partition_size=self.partition_size[i],
# dtype=self.single_partition_of_fp32_groups[i].dtype
#)
#TODO RS: can we safely use dtype of the first sub-partition? i think so
local_grad_sub_partitions = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
[partition_id],
sub_partition_size=self.sub_partition_sizes[i],
dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
num_comm_intervals=self.num_comm_intervals_per_group[i],
default_device=self.local_sub_partitions_of_fp32_groups[i][0].device)
#RS: update all our local params with sub-partition grads
#print("self.local_sub_partitions_of_fp32_groups[i]={}, local_grad_sub_partitions={}".format(len(self.local_sub_partitions_of_fp32_groups[i]), len(local_grad_sub_partitions)))
for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]):
sub_partition_param.grad = local_grad_sub_partitions[idx]
#self.single_partition_of_fp32_groups[i].grad = single_grad_partition
#RS: update free grads for sub-partitions
#release all the gradients since we have already created the necessary copy in dp_grad_partition
self.free_grad_in_param_list(
self.params_in_rank_sub_partitions[i][partition_id])
local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
#RS: update unscale/clip with sub partitions
self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups)
self.optimizer.step()
#RS: clear our sub partition grads
#get rid of the fp32 gradients. Not needed anymore
for group in self.local_sub_partitions_of_fp32_groups:
for idx, sub_partition_param in enumerate(group):
sub_partition_param.grad = None
#group.grad = None
#NOTE RS: removed norm_groups outer loop from original code, i don't think it's needed
#RS: copy all sub-partition fp32 data to fp16 sub partitions
# copy fp32 param data to fp16 partitions w.r.t. our local rank
for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
local_sub_partition_param_fp16.data.copy_(
local_sub_partition_param_fp32.data)
#RS: all_gather/broadcast sub-partitions in separate comm calls
#gather the updated weights from everyone
for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
dist.all_gather(sub_partitions,
sub_partitions[partition_id],
group=self.dp_process_group)
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
return self.overflow
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def backward(self, loss, retain_graph=False):
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict[
'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
return state_dict
def load_state_dict(self, state_dict, load_optimizer_states=True):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
for curr_param, saved_param in zip(curr_group, saved_group):
curr_param.data.copy_(saved_param.data)
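# --- Editor's sketch: not part of the DeepSpeed sources in this diff --------
# A torch-free illustration of the layout produced by
# get_data_parallel_sub_partitions() above: consecutive sub-partitions of the
# flat fp16 buffer are dealt round-robin to the data-parallel ranks, and every
# full pass over the ranks forms one communication interval. The two views the
# method returns index the same chunks as [comm_id][rank] and [rank][comm_id];
# the helper below reproduces that indexing with plain integers and is local
# to this sketch.
def _sketch_sub_partition_layout(num_sub_partitions, world_size):
    comm_view = [[] for _ in range(num_sub_partitions // world_size)]
    rank_view = [[] for _ in range(world_size)]
    for idx in range(num_sub_partitions):
        rank_id = idx % world_size
        comm_id = idx // world_size
        comm_view[comm_id].append(idx)  # [comm] -> entries ordered by rank
        rank_view[rank_id].append(idx)  # [rank] -> entries ordered by comm
    return comm_view, rank_view

# Example: 8 sub-partitions across 4 ranks form 2 communication intervals.
_comm_view, _rank_view = _sketch_sub_partition_layout(8, 4)
assert _comm_view == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert _rank_view == [[0, 4], [1, 5], [2, 6], [3, 7]]
# -----------------------------------------------------------------------------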
import torch
import torch.distributed as dist
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
data_parallel_size = int(dist.get_world_size())
if parameter_parallel_size is None:
parameter_parallel_size = int(data_parallel_size)
print(data_parallel_size, parameter_parallel_size)
assert data_parallel_size % parameter_parallel_size == 0, \
'world size should be divisible by parameter parallel size'
rank = dist.get_rank()
my_group = None
for i in range(dist.get_world_size() // parameter_parallel_size):
ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
my_group = group
return my_group
def pprint(msg):
if not dist.is_initialized() or dist.get_rank() == 0:
print(msg)
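# --- Editor's sketch: not part of the DeepSpeed sources in this diff --------
# A torch-free illustration of the rank grouping built by
# _initialize_parameter_parallel_groups() above: the world is split into
# contiguous blocks of parameter_parallel_size ranks and each rank keeps the
# block it belongs to. The helper below only reproduces the rank arithmetic
# (no process groups are created) and is local to this sketch.
def _sketch_parameter_parallel_ranks(world_size, parameter_parallel_size, rank):
    assert world_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    block = rank // parameter_parallel_size
    return list(range(block * parameter_parallel_size,
                      (block + 1) * parameter_parallel_size))

# Example: a world of 8 ranks split into blocks of 4 puts rank 5 in [4..7].
assert _sketch_parameter_parallel_ranks(8, 4, 5) == [4, 5, 6, 7]
# -----------------------------------------------------------------------------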
......@@ -29,6 +29,14 @@ collections:
tutorials:
output: true
permalink: /:collection/:path/
order:
- getting-started.md
- azure.md
- cifar-10.md
- bert-pretraining.md
- megatron.md
- 1Cycle.md
- lrrt.md
defaults:
- scope:
......
......@@ -18,7 +18,7 @@ lnav:
children:
- title: "Installation"
url: /getting-started/#installation
- title: "Writing Models"
- title: "Writing models"
url: /getting-started/#writing-deepspeed-models
- title: "Training"
url: /getting-started/#training
......@@ -37,19 +37,25 @@ lnav:
url: /docs/config-json/#communication-options
- title: "FP16"
url: /docs/config-json/#fp16-training-options
- title: "ZeRO optimizations"
url: /docs/config-json/#zero-optimizations-for-fp16-training
- title: "Logging"
url: /docs/config-json/#logging
- title: "Activation checkpointing"
url: /docs/config-json/#activation-checkpointing
- title: "Tutorials"
url: /tutorials/
children:
- title: "Getting Started on Azure"
- title: "Getting started"
url: /getting-started/
- title: "Getting started on Azure"
url: /tutorials/azure/
- title: "CIFAR-10"
url: /tutorials/cifar-10/
- title: "Megatron-LM GPT2"
url: /tutorials/megatron/
- title: "BERT Pre-training"
url: /tutorials/bert-pretraining/
- title: "Megatron-LM GPT2"
url: /tutorials/megatron/
- title: "1-Cycle Schedule"
url: /tutorials/1Cycle/
- title: "Learning Rate Range Test"
......
......@@ -12,12 +12,6 @@ layout: archive
{% endif %}
<h2>Features Coming Soon</h2>
{% assign soon = posts | where: "sneak_preview", "true" %}
{% for post in soon %}
{% include archive-single.html %}
{% endfor %}
<h2>{{ site.data.ui-text[site.locale].recent_posts | default: "Recent Posts" }}</h2>
{% assign news = posts | where: "sneak_preview", "false" %}
{% for post in news %}
......