Unverified commit 41db1c2f authored by Jeff Rasley, committed by GitHub
parent 79093d74
#include "cpu_adam.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
// C++ interface
void Adam_Optimizer::Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
float step_size = -1 * _alpha / bias_correction1;
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay4;
if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
grad_4.data = SIMD_LOAD(grads + i);
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
AVX_Data param_4;
param_4.data = SIMD_LOAD(_params + i);
if (_weight_decay > 0)
grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
SIMD_STORE(_params + i, param_4.data);
if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size) {
#pragma omp parallel for
for (size_t k = rounded_size; k < _param_size; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0) grad = param * _weight_decay + grad;
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * bias_correction2 + _eps;
grad = momentum / grad;
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + rounded_size,
(_param_size - rounded_size),
Context::Instance().GetCurrentStream());
}
}
}
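For reference, both the SIMD path and the scalar tail above compute the standard Adam update with the learning rate folded into the bias corrections. A minimal NumPy sketch of the same per-element arithmetic, assuming `lr`, `beta1`, `beta2`, `eps` map to `_alpha`, `_betta1`, `_betta2`, `_eps` and `step` to the count accumulated by `IncrementStep`:

import numpy as np

def adam_step_reference(params, grads, exp_avg, exp_avg_sq, step,
                        lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    # Per-element Adam update mirrored from Adam_Optimizer::Step (arrays updated in place).
    if weight_decay > 0:
        grads = grads + weight_decay * params            # decay folded into the gradient
    exp_avg *= beta1
    exp_avg += (1 - beta1) * grads                       # first moment
    exp_avg_sq *= beta2
    exp_avg_sq += (1 - beta2) * grads * grads            # second moment
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 / np.sqrt(1 - beta2 ** step)
    denom = np.sqrt(exp_avg_sq) * bias_correction2 + eps
    params += (-lr / bias_correction1) * (exp_avg / denom)
    return params, exp_avg, exp_avg_sq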
void Adam_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
AVX_Data grad_4[4];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
AVX_Data momentum_4[4];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
AVX_Data variance_4[4];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
AVX_Data param_4[4];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
if (_weight_decay > 0) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int create_adam_optimizer(int optimizer_id,
float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
{
auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay);
s_optimizers[optimizer_id] = opt;
#if defined(__AVX512__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX512 arithmetic capability." << std::endl;
#else
#if defined(__AVX256__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX2 arithmetic capability." << std::endl;
#else
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with scalar arithmetic capability." << std::endl;
#endif
#endif
return 0;
}
void Adam_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
AVX_Data grad_4[8];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2));
grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5);
grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6);
grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7);
AVX_Data momentum_4[8];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2));
momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5);
momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6);
momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7);
AVX_Data variance_4[8];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2));
variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5);
variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6);
variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7);
AVX_Data param_4[8];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2));
param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5);
param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);
if (_weight_decay > 0) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data);
grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data);
grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data);
grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data);
momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data);
momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data);
momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data);
momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data);
momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data);
momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data);
momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data);
variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data);
variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data);
variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data);
variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data);
variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data);
variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[4].data = SIMD_SQRT(variance_4[4].data);
grad_4[5].data = SIMD_SQRT(variance_4[5].data);
grad_4[6].data = SIMD_SQRT(variance_4[6].data);
grad_4[7].data = SIMD_SQRT(variance_4[7].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[4].data = SIMD_FMA(grad_4[4].data, bias2_sqrt.data, eps_4.data);
grad_4[5].data = SIMD_FMA(grad_4[5].data, bias2_sqrt.data, eps_4.data);
grad_4[6].data = SIMD_FMA(grad_4[6].data, bias2_sqrt.data, eps_4.data);
grad_4[7].data = SIMD_FMA(grad_4[7].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_DIV(momentum_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_DIV(momentum_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data);
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(grad_4[4].data, step_size_4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(grad_4[5].data, step_size_4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(grad_4[6].data, step_size_4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(grad_4[7].data, step_size_4.data, param_4[7].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 2), param_4[4].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 7, param_4[7].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2),
param_4[4].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 2), momentum_4[4].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 5, momentum_4[5].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 6, momentum_4[6].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 7, momentum_4[7].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 2), variance_4[4].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 5, variance_4[5].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 6, variance_4[6].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int ds_adam_step(int optimizer_id,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq)
{
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep();
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
return 0;
}
int ds_adam_step_plus_copy(int optimizer_id,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
__half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep();
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
return 0;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
}
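Once compiled as a torch extension, the bindings above can be exercised directly from Python. A rough usage sketch; the module path `deepspeed.ops.adam.cpu_adam_op` follows the `importlib` call in `cpu_adam.py` further below, and the tensors are illustrative fp32 CPU buffers:

import importlib
import torch

# The compiled module name depends on how the extension is built/installed;
# cpu_adam.py below loads it as 'deepspeed.ops.adam.cpu_adam_op'.
cpu_adam_op = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')

param = torch.randn(1024, dtype=torch.float32)   # fp32 master weights on the CPU
grad = torch.randn_like(param)
exp_avg = torch.zeros_like(param)                # first-moment state
exp_avg_sq = torch.zeros_like(param)             # second-moment state

opt_id = 0
cpu_adam_op.create_adam(opt_id, 1e-3, 0.9, 0.999, 1e-8, 0.0)
cpu_adam_op.adam_update(opt_id, param, grad, exp_avg, exp_avg_sq)   # updates param in place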
#include "custom_cuda_layers.h"
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
const float4* input_cast = reinterpret_cast<const float4*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < size) {
float4 data = input_cast[id];
float2 cast_data;
__half* output_h = reinterpret_cast<__half*>(&cast_data);
output_h[0] = (__half)data.x;
output_h[1] = (__half)data.y;
output_h[2] = (__half)data.z;
output_h[3] = (__half)data.w;
output_cast[id] = cast_data;
}
}
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 512;
size /= 4;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);
param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
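Each thread of `param_update_kernel` loads one `float4` (four fp32 values) and stores one `float2` holding four packed `__half` values, which is why `launch_param_update` divides `size` by 4 and implicitly assumes the element count is a multiple of four (true here because the staged buffers are whole SIMD-width chunks). A small CPU-side sketch of the conversion it performs, assuming PyTorch is available:

import torch

def param_update_reference(fp32_src: torch.Tensor) -> torch.Tensor:
    # CPU emulation of param_update_kernel: convert every fp32 element to fp16.
    # The CUDA kernel does the same thing vectorized: one float4 load (4 fp32)
    # and one float2 store (4 packed __half) per thread.
    assert fp32_src.numel() % 4 == 0, "the vectorized kernel assumes a multiple of 4 elements"
    return fp32_src.to(torch.float16)

src = torch.randn(64)
assert torch.equal(param_update_reference(src), src.half())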
#pragma once
#include <cpuid.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <x86intrin.h>
#include <cassert>
#include "context.h"
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define TILE (1024 * 1024 * 1024)
#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16
#else
#if defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8
#endif
#endif
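`SIMD_WIDTH` together with the `ROUND_DOWN` macro in cpu_adam.cpp decides how many elements take the vectorized path and how many fall through to the scalar tail. A small illustration (Python; the numbers assume AVX-512, i.e. 16 lanes):

def round_down(size: int, step: int) -> int:
    # mirrors the C macro ROUND_DOWN(size, step): ((size) & ~((step) - 1));
    # step must be a power of two
    return size & ~(step - 1)

param_size = 1000
simd_width = 16                       # AVX-512 lanes; 8 under AVX2
rounded = round_down(param_size, simd_width)
print(rounded, param_size - rounded)  # 992 elements vectorized, 8 left for the scalar tail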
class Adam_Optimizer {
public:
Adam_Optimizer(float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_buf_index(false)
{
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
}
~Adam_Optimizer()
{
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
}
void Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr);
void Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr);
void Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params = nullptr);
inline void IncrementStep()
{
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
private:
#if defined(__AVX512__) or defined(__AVX256__)
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#else
__m256 data;
#endif
// float data_f[16];
};
#endif
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
float* _doubled_buffer[2];
bool _buf_index;
};
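`_doubled_buffer` holds two pinned (page-locked) host buffers allocated with `cudaMallocHost`, and `_buf_index` ping-pongs between them: while `launch_param_update` asynchronously copies one freshly written tile of fp16 parameters to the GPU, the CPU fills the other. Pinned memory is what makes that copy truly asynchronous; the same idea at the PyTorch level, as a sketch:

import torch

# Pinned host memory lets the H2D copy issued by launch_param_update overlap
# with the CPU filling the other buffer; the PyTorch equivalent (guarded
# because it needs a GPU):
if torch.cuda.is_available():
    host_buf = torch.empty(1 << 20, dtype=torch.float32, pin_memory=True)
    dev_buf = torch.empty(1 << 20, dtype=torch.float32, device='cuda')
    dev_buf.copy_(host_buf, non_blocking=True)   # returns immediately; copy runs on the stream
    torch.cuda.synchronize()                     # wait before reusing host_buf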
@@ -264,3 +264,5 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
                                        int rows,
                                        int cols,
                                        cudaStream_t stream);
+
+void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
@@ -4,16 +4,18 @@ Copyright 2020 The Microsoft DeepSpeed Team
import sys
import types

-from deepspeed.runtime.engine import DeepSpeedEngine
-from deepspeed.runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
-from deepspeed.runtime.lr_schedules import add_tuning_arguments
-from deepspeed.runtime.config import DeepSpeedConfig
-from deepspeed.runtime.activation_checkpointing import checkpointing
-from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
-from deepspeed.utils import logger
+from . import ops
+
+from .runtime.engine import DeepSpeedEngine
+from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
+from .runtime.lr_schedules import add_tuning_arguments
+from .runtime.config import DeepSpeedConfig
+from .runtime.activation_checkpointing import checkpointing
+from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
+from .utils import logger

try:
-    from deepspeed.git_version_info import version, git_hash, git_branch
+    from .git_version_info import version, git_hash, git_branch
except ImportError:
    version = "0.0.0+unknown"
    git_hash = None
from ..git_version_info import installed_ops as __installed_ops__
from . import lamb
from . import transformer
if __installed_ops__['sparse-attn']:
    from . import sparse_attention
if __installed_ops__['cpu-adam']:
    from . import adam
from .cpu_adam import DeepSpeedCPUAdam
import math
import torch
import importlib

ds_opt_adam = None


class DeepSpeedCPUAdam(torch.optim.Optimizer):

    optimizer_id = 0

    def __init__(self,
                 model_params,
                 lr=1e-3,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 amsgrad=False):

        default_args = dict(lr=lr,
                            betas=betas,
                            eps=eps,
                            weight_decay=weight_decay,
                            amsgrad=amsgrad)
        super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)

        self.opt_id = DeepSpeedCPUAdam.optimizer_id
        DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1

        global ds_opt_adam
        ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
        ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)

    def __setstate__(self, state):
        super(DeepSpeedCPUAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None, fp16_param_groups=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0
                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if fp16_param_groups is not None:
                    p_fp16 = fp16_param_groups[group_id][param_id]
                    ds_opt_adam.adam_update_copy(self.opt_id,
                                                 p.data,
                                                 grad,
                                                 exp_avg,
                                                 exp_avg_sq,
                                                 p_fp16)
                else:
                    ds_opt_adam.adam_update(self.opt_id,
                                            p.data,
                                            grad,
                                            exp_avg,
                                            exp_avg_sq)
        return loss
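A short usage sketch for the optimizer above, assuming the cpu-adam op has been built and installed (with ZeRO-Offload the DeepSpeed engine constructs this optimizer itself when the config selects `deepspeed_adam`):

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam   # exported by ops/adam/__init__.py above

model = torch.nn.Linear(64, 64)                    # parameters, grads and state stay on the CPU
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, weight_decay=0.0)

loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()
optimizer.step()                                   # dispatches to the vectorized C++ adam_update
optimizer.zero_grad()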
import torch
from torch.autograd import Variable
import collections.abc   # the ABC aliases in plain 'collections' were removed in Python 3.10


def async_migrate_to(obj, dev, main_stream=None):
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        # 'async' became a reserved word in Python 3.7; non_blocking is the
        # replacement keyword argument for Tensor.cuda().
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            v.data.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.abc.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj


def async_copy_to(obj, dev, main_stream=None):
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        target = torch.empty_like(obj, device=dev).copy_(obj)
        if main_stream is not None:
            target.data.record_stream(main_stream)
        return target
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.abc.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj
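Both helpers recurse through nested containers and preserve their structure; a brief illustration of `async_copy_to`, assuming a CUDA device is present:

import torch

batch = {'input': torch.randn(4, 16), 'labels': [torch.tensor(1), torch.tensor(0)]}
if torch.cuda.is_available():
    gpu_batch = async_copy_to(batch, dev='cuda')
    print(gpu_batch['input'].device)   # cuda:0 -- the dict/list structure is preserved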
@@ -10,14 +10,21 @@ from deepspeed.runtime.constants import *
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
+from deepspeed.runtime.zero.constants import *
from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from deepspeed.utils import logger

TENSOR_CORE_ALIGN_SIZE = 8
-ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
-DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER]
+ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
+DEEPSPEED_ADAM = 'deepspeed_adam'
+DEEPSPEED_OPTIMIZERS = [
+    ADAM_OPTIMIZER,
+    LAMB_OPTIMIZER,
+    ONEBIT_ADAM_OPTIMIZER,
+    DEEPSPEED_ADAM
+]

def get_amp_enabled(param_dict):
@@ -111,22 +118,9 @@ def get_zero_optimization(param_dict):
def get_zero_reduce_scatter(param_dict):
-    return get_scalar_param(param_dict, ZERO_REDUCE_SCATTER, ZERO_REDUCE_SCATTER_DEFAULT)
-
-
-def get_zero_max_elements_per_comm(param_dict):
    return get_scalar_param(param_dict,
-                            ZERO_MAX_ELEMENTS_PER_COMM,
-                            ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT)
-
-
-def get_allgather_size(param_dict):
-    return get_scalar_param(param_dict,
-                            ALLGATHER_SIZE,
-                            ALLGATHER_SIZE_DEFAULT) if get_scalar_param(
-                                param_dict,
-                                ALLGATHER_SIZE,
-                                ALLGATHER_SIZE_DEFAULT) > 0 else ALLGATHER_SIZE_DEFAULT
+                            ZERO_OPTIMIZATION_REDUCE_SCATTER,
+                            ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)

def get_allreduce_always_fp32(param_dict):
@@ -493,8 +487,6 @@ class DeepSpeedConfig(object):
        self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
        self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)

-        self.allgather_size = get_allgather_size(param_dict)
-
        self.zero_config = DeepSpeedZeroConfig(param_dict)
        self.zero_optimization_stage = self.zero_config.stage
        self.zero_enabled = self.zero_optimization_stage > 0
@@ -628,15 +620,18 @@ class DeepSpeedConfig(object):
                            ':'))))

    def _do_error_check(self):
-        if self.zero_enabled:
-            assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
-            assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
-
        assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)

-        assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
+        assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
            GRADIENT_ACCUMULATION_STEPS)

+        if self.zero_enabled:
+            assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
+            assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
+            if self.zero_config.cpu_offload is True:
+                assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
+                #assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
+
    def _do_warning_check(self):
        fp16_enabled = self.fp16_enabled or self.zero_enabled
@@ -183,35 +183,6 @@ Gradient clipping should be enabled as:
GRADIENT_CLIPPING = 'gradient_clipping'
GRADIENT_CLIPPING_DEFAULT = 0.

-#########################################
-# ZeRO optimization
-#########################################
-# ZeRO optimization. By default, this optimization is not enabled.
-# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
-ZERO_FORMAT = '''
-ZeRO optimization should be enabled as:
-"session_params": {
-  "zero_optimization": [0|1|2],
-  "zero_all_gather_size": 200
-}
-'''
-ZERO_OPTIMIZATION = 'zero_optimization'
-ZERO_OPTIMIZATION_DEFAULT = 0
-ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
-ZERO_OPTIMIZATION_GRADIENTS = 2
-ZERO_OPTIMIZATION_WEIGHTS = 3
-MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
-ZERO_REDUCE_SCATTER = "zero_reduce_scatter"
-ZERO_REDUCE_SCATTER_DEFAULT = True
-ZERO_MAX_ELEMENTS_PER_COMM = "zero_max_elements_per_comm"
-ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT = 5e8
-ALLGATHER_SIZE = 'allgather_size'
-ALLGATHER_SIZE_DEFAULT = 500000000

#########################################
# FP32 AllReduce
#########################################
@@ -7,6 +7,7 @@ import torch
import warnings
import torch.distributed as dist
+import apex
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
@@ -14,20 +15,20 @@ from tensorboardX import SummaryWriter
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
+from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
-    ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS
+    ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_ADAM, DEEPSPEED_OPTIMIZERS
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
    ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
-    TORCH_DISTRIBUTED_DEFAULT_PORT, \
+    TORCH_DISTRIBUTED_DEFAULT_PORT
+from deepspeed.runtime.zero.constants import \
    ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
@@ -105,7 +106,6 @@ class DeepSpeedEngine(Module):
                 collate_fn=None,
                 config_params=None):
        super(DeepSpeedEngine, self).__init__()
        self.client_optimizer = optimizer
        self.client_model_parameters = model_parameters
        self.client_lr_scheduler = lr_scheduler
@@ -266,7 +266,7 @@ class DeepSpeedEngine(Module):
        return self._config.train_micro_batch_size_per_gpu

    def optimizer_name(self):
-        return self._config.optimizer_name
+        return self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name

    def optimizer_params(self):
        return self._config.optimizer_params
@@ -292,6 +292,9 @@ class DeepSpeedEngine(Module):
    def zero_overlap_comm(self):
        return self._config.zero_config.overlap_comm

+    def zero_cpu_offload(self):
+        return self._config.zero_config.cpu_offload
+
    def zero_optimization_stage(self):
        return self._config.zero_optimization_stage
@@ -310,9 +313,6 @@ class DeepSpeedEngine(Module):
    def zero_load_from_fp32_weights(self):
        return self._config.zero_config.load_from_fp32_weights

-    def allgather_size(self):
-        return self._config.allgather_size
-
    def fp16_enabled(self):
        return self._config.fp16_enabled
@@ -491,6 +491,7 @@ class DeepSpeedEngine(Module):
    # Configure optimizer
    def _configure_optimizer(self, client_optimizer, model_parameters):
        if client_optimizer is not None:
            basic_optimizer = client_optimizer
            logger.info('Using client Optimizer as basic optimizer')
@@ -504,13 +505,14 @@ class DeepSpeedEngine(Module):
        if self.zero_optimization():
            assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
-            if self.optimizer_name() != ADAM_OPTIMIZER:
+            if not is_zero_supported_optimizer(basic_optimizer):
                assert self.zero_allow_untested_optimizer(), \
                    'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
                logger.warning(
                    "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                )
            self.optimizer = self._configure_zero_optimizer(basic_optimizer)
        elif self.amp_enabled():
            assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
@@ -522,8 +524,8 @@ class DeepSpeedEngine(Module):
            self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
        else:
            self.optimizer = basic_optimizer
-        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
-        # logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
+        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))

    def _configure_basic_optimizer(self, model_parameters):
        optimizer_parameters = self.optimizer_params()
@@ -533,8 +535,14 @@ class DeepSpeedEngine(Module):
                "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
            )
        if self.optimizer_name() == ADAM_OPTIMIZER:
-            from apex.optimizers.fused_adam import FusedAdam
-            optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+            if self.zero_cpu_offload():
+                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
+            else:
+                from apex.optimizers.fused_adam import FusedAdam
+                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+        elif self.optimizer_name() == DEEPSPEED_ADAM:
+            from deepspeed.ops.adam import DeepSpeedCPUAdam
+            optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == LAMB_OPTIMIZER:
            from deepspeed.ops.lamb import FusedLamb
            optimizer = FusedLamb(model_parameters, **optimizer_parameters)
@@ -550,8 +558,9 @@ class DeepSpeedEngine(Module):
        initial_dynamic_scale = self.initial_dynamic_scale()
        dynamic_loss_args = self.dynamic_loss_scale_args()
        clip_grad = self.gradient_clipping()
-        if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name(
-        ) == ONEBIT_ADAM_OPTIMIZER:
+        if isinstance(optimizer,
+                      apex.optimizers.FusedAdam) or self.optimizer_name(
+                      ) == ONEBIT_ADAM_OPTIMIZER:
            if self.dynamic_loss_scale():
                logger.info('Creating fp16 optimizer with dynamic loss scale')
                timers = self.timers if self.wall_clock_breakdown() else None
@@ -616,9 +625,11 @@ class DeepSpeedEngine(Module):
                dp_process_group=self.data_parallel_group,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=self.zero_overlap_comm(),
+                cpu_offload=self.zero_cpu_offload(),
                mpu=self.mpu,
                postscale_gradients=self.postscale_gradients(),
-                gradient_predivide_factor=self.gradient_predivide_factor())
+                gradient_predivide_factor=self.gradient_predivide_factor(),
+                gradient_accumulation_steps=self.gradient_accumulation_steps())
        else:
            raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
@@ -724,7 +735,6 @@ class DeepSpeedEngine(Module):
        return loss

    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
        #Zero stage 2 communicates during non gradient accumulation boundaries as well
        if self.zero_optimization_partition_gradients():
            self.optimizer.overlapping_partition_gradients_reduce_epilogue()
@@ -780,6 +790,8 @@ class DeepSpeedEngine(Module):
        self.timers('backward_inner').start()

        if self.zero_optimization():
+            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
+            )
            self.optimizer.backward(loss)
        elif self.amp_enabled():
            # AMP requires delaying unscale when inside gradient accumulation boundaries
@@ -854,7 +866,6 @@ class DeepSpeedEngine(Module):
            master_params = amp.master_params(self.optimizer)
            torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                           max_norm=self.gradient_clipping())
        self.optimizer.step()

        #zero grad in basic optimizer could be unreliable and may not exhibit
@@ -957,6 +968,9 @@ class DeepSpeedEngine(Module):
    def get_lr(self):
        return self._get_optimizer_param('lr')

+    def get_type(self):
+        return self._get_optimizer_param('type')
+
    def get_mom(self):
        return self._get_optimizer_param('betas')
@@ -5,79 +5,7 @@ Licensed under the MIT license.
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.utils import logger
+from deepspeed.runtime.zero.constants import *
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER:
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
}
class DeepSpeedZeroConfig(object):
@@ -92,6 +20,7 @@ class DeepSpeedZeroConfig(object):
        self.allgather_bucket_size = None
        self.overlap_comm = None
        self.load_from_fp32_weights = None
+        self.cpu_offload = None

        if ZERO_OPTIMIZATION in param_dict.keys():
            zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -156,7 +85,12 @@ class DeepSpeedZeroConfig(object):
                zero_config_dict,
                ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
                ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
            self.load_from_fp32_weights = get_scalar_param(
                zero_config_dict,
                ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
                ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
+            self.cpu_offload = get_scalar_param(zero_config_dict,
+                                                ZERO_OPTIMIZATION_CPU_OFFLOAD,
+                                                ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
"cpu_offload": [true|false]
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
}
@@ -793,8 +793,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
    def _get_state_without_padding(self, state_with_padding, padding):
        lean_state = {}
        for key, value in state_with_padding.items():
            if torch.is_tensor(value):
                lean_length = value.numel() - padding
                lean_state[key] = value[:lean_length]
            else:
                lean_state[key] = value

        return lean_state
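# Illustrative sketch only (the helper below is hypothetical and not part of the
# optimizer): it mirrors what _get_state_without_padding does to a padded Adam
# state, trimming alignment padding from tensor entries while passing scalar
# entries such as "step" through unchanged.
def _strip_padding_example():
    import torch
    padded_state = {"exp_avg": torch.zeros(12), "step": 3}
    padding = 2
    lean_state = {}
    for key, value in padded_state.items():
        if torch.is_tensor(value):
            lean_state[key] = value[:value.numel() - padding]  # keep the first 10 elements
        else:
            lean_state[key] = value  # non-tensor state is kept as-is
    assert lean_state["exp_avg"].numel() == 10 and lean_state["step"] == 3
    return lean_state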
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
from torch._six import inf
from torch.autograd import Variable
import collections

from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger

#Toggle this to true to enable correctness test
#with gradient partitioning and without
pg_correctness_test = False

try:
    from apex_C import flatten
    from apex_C import unflatten
except ImportError:
    try:
        _ = warned_flatten
    except NameError:
        logger.warning(
            "apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten


def input(msg):
    return
def split_half_float_double(tensors):
    dtypes = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor"
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets


def isclose(a, b, rtol=1e-09, atol=0.0):
    return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)


def lcm(x, y):
    from fractions import gcd  # or can import gcd from `math` in Python 3
    return x * y // gcd(x, y)
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment):
    num_elements = 0
    for tensor in tensor_list:
        num_elements = num_elements + tensor.numel()

    remaining = num_elements % alignment

    if remaining:
        elements_to_add = alignment - remaining
        pad_tensor = torch.zeros(elements_to_add,
                                 device=tensor_list[0].device,
                                 dtype=tensor_list[0].dtype)
        padded_tensor_list = tensor_list + [pad_tensor]

        num_elements = num_elements + elements_to_add
    else:
        padded_tensor_list = tensor_list

    return _flatten_dense_tensors(padded_tensor_list)


def get_alignment_padding(tensor_list, alignment):
    num_elements = sum([tensor.numel() for tensor in tensor_list])
    remainder = num_elements % alignment
    return (alignment - remainder) if remainder else remainder


def move_to_cpu(tensor_list):
    for tensor in tensor_list:
        tensor.data = tensor.data.cpu()


def print_rank_msg(msg):
    print(f"rank {dist.get_rank()} - {msg}")
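# Illustrative sketch only (this hypothetical helper is not used by the optimizer):
# with three tensors of 5, 7 and 9 elements (21 in total) and an alignment of 4,
# get_alignment_padding returns 3 and flatten_dense_tensors_aligned pads the flat
# buffer to 24 elements, which splits evenly into 4 partitions of 6 elements each.
def _alignment_padding_example():
    tensors = [torch.zeros(5), torch.zeros(7), torch.zeros(9)]
    alignment = 4
    padding = get_alignment_padding(tensors, alignment)
    flat_buffer = flatten_dense_tensors_aligned(tensors, alignment)
    assert padding == 3 and flat_buffer.numel() == 24
    return flat_buffer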
class FP16_DeepSpeedZeroOptimizer(object):
    """
    DeepSpeedZeroOptimizer designed to reduce the memory footprint
    required for training large deep learning models.

    For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
    https://arxiv.org/abs/1910.02054

    For usage examples, refer to TODO: DeepSpeed Tutorial
    """
    def __init__(self,
                 init_optimizer,
                 timers,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 contiguous_gradients=True,
                 reduce_bucket_size=500000000,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 reduce_scatter=True,
                 overlap_comm=False,
                 cpu_offload=False,
                 mpu=None,
                 clip_grad=0.0,
                 allreduce_always_fp32=False,
                 postscale_gradients=True,
                 gradient_predivide_factor=1.0,
                 gradient_accumulation_steps=1):

        if dist.get_rank() == 0:
            logger.info(f"Reduce bucket size {reduce_bucket_size}")
            logger.info(f"Allgather bucket size {allgather_bucket_size}")
            logger.info(f"CPU Offload: {cpu_offload}")
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain same user API from apex.fp16_utils
        # 2. keep common stuff here in case we need to add a new fused optimizer later

        # differences from apex.fp16_utils:
        # - assume all model params in fp16
        # - assume all params requires grad
        # - flat by groups, not keeping state. TODO: remove state explicitly?
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        self.timers = timers

        self.reduce_scatter = reduce_scatter

        self.overlap_comm = overlap_comm

        self.cpu_offload = cpu_offload

        self.deepspeed_adam_offload = (cpu_offload
                                       and type(init_optimizer) == DeepSpeedCPUAdam)

        self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'

        self.dp_process_group = dp_process_group

        self.partition_count = dist.get_world_size(group=self.dp_process_group)

        if mpu is None:
            self.model_parallel_group = None
            self.model_parallel_rank = 0
        else:
            self.model_parallel_group = mpu.get_model_parallel_group()
            self.model_parallel_rank = mpu.get_model_parallel_rank()

        self.overflow = False
        self.clip_grad = clip_grad
        self.allreduce_always_fp32 = allreduce_always_fp32
        self.gradient_predivide_factor = gradient_predivide_factor
        self.postscale_gradients = postscale_gradients
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.micro_step_id = 0

        if self.reduce_scatter:
            assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
            assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
            assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
        # param flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []

        #param partitioned by data parallel degree
        #this will contain a list of equal sized tensors
        #each of which will be updated by a different process
        self.parallel_partitioned_fp16_groups = []

        #a single 32-bit partition of the parallel partitioned parameters
        #that this process will update
        self.single_partition_of_fp32_groups = []

        #param partition info

        #These are the parameters in each group that will not be updated by this process directly
        self.params_not_in_partition = []

        #These are the parameters that will be updated by this process directly
        self.params_in_partition = []

        #Offset from the first parameter in the self.params_in_partition
        #the parameter boundaries may not align with partition boundaries
        #so we need to keep track of the offset
        self.first_offset = []

        #number of elements per partition in each group
        self.partition_size = []

        partition_id = dist.get_rank(group=self.dp_process_group)

        self.all_reduce_print = False

        # padding on each partition for alignment purposes
        self.groups_padding = []
        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to list before modify
            self.fp16_groups.append(param_group['params'])
            # Record padding required to align group to world size
            if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
                padding = get_alignment_padding(self.fp16_groups[i],
                                                self.partition_count)
            else:
                padding = 0
            self.groups_padding.append(padding)

            #not sure why apex was cloning the weights before flattening
            #removing cloning here

            see_memory_usage(f"Before moving param group {i} to CPU")
            #move all the parameters to cpu to free up GPU space for creating flat buffer
            move_to_cpu(self.fp16_groups[i])
            see_memory_usage(f"After moving param group {i} to CPU")

            #create flat buffer in CPU and move to GPU
            self.fp16_groups_flat.append(
                flatten_dense_tensors_aligned(
                    self.fp16_groups[i],
                    dist.get_world_size(group=self.dp_process_group)).cuda(
                        torch.cuda.current_device()))
            see_memory_usage(f"After flattening and moving param group {i} to GPU")

            if dist.get_rank(group=self.dp_process_group) == 0:
                see_memory_usage(
                    f"After Flattening and after emptying param group {i} cache")

            # set model fp16 weight to slices of flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

            #divide the flat weights into near-equal partitions equal to the data parallel degree
            #each process will compute on a different part of the partition
            data_parallel_partitions = self.get_data_parallel_partitions(
                self.fp16_groups_flat[i])
            self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)

            # a partition of the fp32 master weights that will be updated by this process
            self.single_partition_of_fp32_groups.append(
                self.parallel_partitioned_fp16_groups[i][partition_id].to(
                    self.device).clone().float().detach())

            # modify optimizer to have flat master weight
            self.single_partition_of_fp32_groups[
                i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.single_partition_of_fp32_groups[i]]

            partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
                group=self.dp_process_group)
            params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(self.fp16_groups[i], partition_size, partition_id)

            self.partition_size.append(partition_size)
            self.params_in_partition.append(params_in_partition)
            self.params_not_in_partition.append(params_not_in_partition)
            self.first_offset.append(first_offset)
        self.reduce_bucket_size = int(reduce_bucket_size)
        self.allgather_bucket_size = int(allgather_bucket_size)

        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
        self.reduction_stream = torch.cuda.Stream()
        self.cpu_computation_stream = torch.cuda.Stream()
        self.migration_stream = torch.cuda.Stream()
        self.callback_queued = False

        self.param_dict = {}

        #map between param_id and bool to specify if a param is in this partition
        self.is_param_in_current_partition = {}

        # CPU-Offload requires contiguous gradients
        self.contiguous_gradients = contiguous_gradients or cpu_offload
        self.grads_in_ipg_bucket = []
        self.params_in_ipg_bucket = []
        self.elements_in_ipg_bucket = 0
        self.params_already_reduced = []
        self._release_ipg_buffers()
        self.previous_reduced_grads = None

        #simplified param id
        self.param_id = {}

        largest_param_numel = 0
        count = 0
        for i, params_group in enumerate(self.fp16_groups):
            for param in params_group:
                unique_id = id(param)
                self.param_id[unique_id] = count
                self.param_dict[count] = param
                self.params_already_reduced.append(False)
                if param.numel() > largest_param_numel:
                    largest_param_numel = param.numel()
                count = count + 1

        for param_group in self.params_in_partition:
            for param in param_group:
                self.is_param_in_current_partition[self.get_param_id(param)] = True

        for param_group in self.params_not_in_partition:
            for param in param_group:
                self.is_param_in_current_partition[self.get_param_id(param)] = False

        if self.cpu_offload:
            self.accumulated_grads_in_cpu = {}
            self.norm_for_param_grads = {}
            self.local_overflow = False
            self.grad_position = {}
            self.temp_grad_buffer_for_cpu_offload = torch.zeros(
                largest_param_numel,
                device=self.device).half().pin_memory()
            self.temp_grad_buffer_for_gpu_offload = torch.zeros(
                largest_param_numel,
                device=torch.cuda.current_device()).half()

            for i, params_group in enumerate(self.fp16_groups):
                self.get_grad_position(i,
                                       self.params_in_partition[i],
                                       self.first_offset[i],
                                       self.partition_size[i])

        #mapping from parameter to partition that it belongs to
        self.param_to_partition_ids = {}

        #stores if a partition has been reduced in this step
        self.is_partition_reduced = {}

        #number of grads in partition that still need to be computed
        self.remaining_grads_in_partition = {}

        #total number of grads in partition
        self.total_grads_in_partition = {}

        #stores if a grad in a partition has been computed or not
        self.is_grad_computed = {}

        #stores the offset at which a parameter gradient needs to be inserted in a partition
        self.grad_partition_insertion_offset = {}

        #the offset in the gradient at which it must be inserted at the beginning of the partition
        self.grad_start_offset = {}

        #will store the averaged gradients required by this partition
        self.averaged_gradients = {}

        # store index of first parameter in each partition
        self.first_param_index_in_partition = {}

        #initializes all data structures for implementing gradient partitioning
        self.initialize_gradient_partitioning_data_structures()

        #resets the data structure value for the next backward propagation
        self.reset_partition_gradient_structures()

        #creates backward hooks for gradient partitioning
        self.create_reduce_and_remove_grad_hooks()

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            if dynamic_loss_args is None:
                self.loss_scaler = DynamicLossScaler()
            else:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)

            self.dynamic_loss_scale = True

        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(scale=static_loss_scale)
            self.cur_iter = 0

        see_memory_usage("Before initializing optimizer states")
        self.initialize_optimizer_states()
        see_memory_usage("After initializing optimizer states")

        if dist.get_rank() == 0:
            logger.info(f"optimizer state initialized")

        if dist.get_rank(group=self.dp_process_group) == 0:
            see_memory_usage(f"After initializing ZeRO optimizer")
    def _release_ipg_buffers(self):
        if self.contiguous_gradients:
            self.ipg_buffer = None
            self.grads_in_partition = None
            self.grads_in_partition_offset = 0

    def initialize_optimizer_states(self):

        for i, group in enumerate(self.fp16_groups):
            single_grad_partition = torch.zeros(
                int(self.partition_size[i]),
                dtype=self.single_partition_of_fp32_groups[i].dtype,
                device=self.device)
            self.single_partition_of_fp32_groups[
                i].grad = single_grad_partition.pin_memory(
                ) if self.cpu_offload else single_grad_partition

        self.optimizer.step()

        if not self.cpu_offload:
            for group in self.single_partition_of_fp32_groups:
                group.grad = None

        return

    #########################################################################
    #########################ZeRO Partition Gradients########################
    #########################################################################

    def get_first_param_index(self, group_id, param_group, partition_id):
        for index, param in enumerate(param_group):
            param_id = self.get_param_id(param)
            if partition_id in self.param_to_partition_ids[group_id][param_id]:
                return index
        return None
    def initialize_gradient_partitioning_data_structures(self):

        total_partitions = dist.get_world_size(group=self.dp_process_group)

        for i, param_group in enumerate(self.fp16_groups):

            self.param_to_partition_ids[i] = {}
            self.is_partition_reduced[i] = {}
            self.total_grads_in_partition[i] = {}
            self.remaining_grads_in_partition[i] = {}
            self.is_grad_computed[i] = {}
            self.grad_partition_insertion_offset[i] = {}
            self.grad_start_offset[i] = {}
            self.first_param_index_in_partition[i] = {}

            for partition_id in range(total_partitions):
                self.is_grad_computed[i][partition_id] = {}
                self.grad_partition_insertion_offset[i][partition_id] = {}
                self.grad_start_offset[i][partition_id] = {}
                self.total_grads_in_partition[i][partition_id] = 0
                self.initialize_gradient_partition(i, param_group, partition_id)
                self.is_partition_reduced[i][partition_id] = False
                self.first_param_index_in_partition[i][
                    partition_id] = self.get_first_param_index(
                        i,
                        param_group,
                        partition_id)

    def independent_gradient_partition_epilogue(self):
        self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
        self.reduce_ipg_grads()
        self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)

        #if dist.get_rank() == 0:
        #    logger.info("Params already reduced %s", self.params_already_reduced)
        for i in range(len(self.params_already_reduced)):
            self.params_already_reduced[i] = False

        if self.overlap_comm:
            torch.cuda.synchronize()

        if self.cpu_offload is False:
            for i, _ in enumerate(self.fp16_groups):
                if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
                    self.averaged_gradients[i] = self.get_flat_partition(
                        self.params_in_partition[i],
                        self.first_offset[i],
                        self.partition_size[i],
                        dtype=torch.half,
                        device=torch.cuda.current_device(),
                        return_tensor_list=True)
                else:
                    #When gradient accumulation is greater than 1
                    #this code path will be triggered and will add
                    #to the accumulated averaged gradients
                    avg_new = self.get_flat_partition(self.params_in_partition[i],
                                                      self.first_offset[i],
                                                      self.partition_size[i],
                                                      dtype=torch.half,
                                                      device=torch.cuda.current_device(),
                                                      return_tensor_list=True)

                    for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
                        accumulated_grad.add_(new_avg_grad)

        self._release_ipg_buffers()

        # No need to keep the gradients anymore.
        # All gradients required by the step
        # are in self.averaged_gradients
        self.zero_grad()
        see_memory_usage(f"End ipg_epilogue")

    # resets all partitions to not reduced
    # sets remaining grads to the total number of grads in each partition
    # sets is_grad_computed to false for all grads in partition
    def reset_partition_gradient_structures(self):
        total_partitions = dist.get_world_size(group=self.dp_process_group)
        for i, _ in enumerate(self.fp16_groups):
            for partition_id in range(total_partitions):
                self.is_partition_reduced[i][partition_id] = False
                self.remaining_grads_in_partition[i][
                    partition_id] = self.total_grads_in_partition[i][partition_id]

                for param_id in self.is_grad_computed[i][partition_id]:
                    self.is_grad_computed[i][partition_id][param_id] = False

    def initialize_gradient_partition(self, i, param_group, partition_id):
        def set_key_value_list(dictionary, key, value):
            if key in dictionary:
                dictionary[key].append(value)
            else:
                dictionary[key] = [value]

        def increment_value(dictionary, key):
            if key in dictionary:
                dictionary[key] += 1
            else:
                dictionary[key] = 1

        partition_size = self.partition_size[i]

        start_index = partition_size * partition_id
        end_index = partition_size * (partition_id + 1)

        current_index = 0
        first_offset = 0

        for param in param_group:

            param_size = param.numel()
            param_id = self.get_param_id(param)

            if (current_index >= start_index and current_index < end_index):
                set_key_value_list(self.param_to_partition_ids[i],
                                   param_id,
                                   partition_id)
                increment_value(self.total_grads_in_partition[i], partition_id)

                self.is_grad_computed[i][partition_id][param_id] = False

                self.grad_partition_insertion_offset[i][partition_id][
                    param_id] = current_index - start_index
                self.grad_start_offset[i][partition_id][param_id] = 0

            elif start_index > current_index and start_index < (current_index +
                                                                param_size):
                assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
                first_offset = start_index - current_index

                set_key_value_list(self.param_to_partition_ids[i],
                                   param_id,
                                   partition_id)
                increment_value(self.total_grads_in_partition[i], partition_id)

                self.is_grad_computed[i][partition_id][param_id] = False

                self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
                self.grad_start_offset[i][partition_id][param_id] = first_offset

            current_index = current_index + param_size
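    # Worked example (illustrative numbers only): with partition_size = 100 and
    # partition_id = 1, this partition owns flat indices [100, 200). A parameter of
    # 60 elements starting at current_index = 70 straddles the boundary, so it is
    # registered for partition 1 with first_offset = 100 - 70 = 30,
    # grad_partition_insertion_offset = 0 and grad_start_offset = 30, while a
    # parameter starting at current_index = 130 is registered with
    # grad_partition_insertion_offset = 130 - 100 = 30 and grad_start_offset = 0.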
    def overlapping_partition_gradients_reduce_epilogue(self):
        self.independent_gradient_partition_epilogue()

    def create_reduce_and_remove_grad_hooks(self):
        self.grad_accs = []
        for i, param_group in enumerate(self.fp16_groups):
            for param in param_group:
                if param.requires_grad:

                    def wrapper(param, i):
                        param_tmp = param.expand_as(param)
                        grad_acc = param_tmp.grad_fn.next_functions[0][0]

                        def reduce_partition_and_remove_grads(*notneeded):
                            self.reduce_ready_partitions_and_remove_grads(param, i)

                        grad_acc.register_hook(reduce_partition_and_remove_grads)
                        self.grad_accs.append(grad_acc)

                    wrapper(param, i)

    def get_param_id(self, param):
        unique_id = id(param)
        return self.param_id[unique_id]

    def report_ipg_memory_usage(self, tag, param_elems):
        elem_count = self.elements_in_ipg_bucket + param_elems
        percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
        see_memory_usage(
            f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
        )

    ###############Independent Partition Gradient ########################
    def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
        if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
                                         param.numel())
            self.reduce_ipg_grads()
            if self.contiguous_gradients and self.overlap_comm:
                # Swap ipg_index between 0 and 1
                self.ipg_index = 1 - self.ipg_index
            self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
                                         param.numel())

        param_id = self.get_param_id(param)

        assert self.params_already_reduced[param_id] == False, \
            f"The parameter {param_id} has already been reduced. \
            Gradient computed twice for this partition. \
            Multiple gradient reduction is currently not supported"

        #keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
        if self.contiguous_gradients:
            new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
                0,
                self.elements_in_ipg_bucket,
                param.numel())
            new_grad_tensor.copy_(param.grad.view(-1))
            param.grad.data = new_grad_tensor.data.view_as(param.grad)

        self.elements_in_ipg_bucket += param.numel()
        self.grads_in_ipg_bucket.append(param.grad)
        self.params_in_ipg_bucket.append((i, param, param_id))

        self.report_ipg_memory_usage("End ipg_remove_grads", 0)

    def print_rank_0(self, message):
        if dist.get_rank() == 0:
            logger.info(message)

    def gradient_reduction_w_predivide(self, tensor):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)

        tensor_to_allreduce = tensor

        if self.allreduce_always_fp32:
            tensor_to_allreduce = tensor.float()

        if self.postscale_gradients:
            if self.gradient_predivide_factor != 1.0:
                tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

            if self.gradient_predivide_factor != dp_world_size:
                tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
        else:
            tensor_to_allreduce.div_(dp_world_size)
            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

        if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor
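    # Worked example (illustrative numbers only): with postscale_gradients=True,
    # gradient_predivide_factor=2 and dp_world_size=8, every rank scales its
    # gradient by 1/2, the all-reduce sums the 8 scaled gradients, and the final
    # multiply by 2/8 leaves sum(g_i)/8, the same data-parallel mean that the
    # else-branch computes by dividing by dp_world_size before the all-reduce.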
    def average_tensor(self, tensor):
        if self.overlap_comm:
            torch.cuda.synchronize()
            stream = self.reduction_stream
        else:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            if not self.reduce_scatter:
                self.gradient_reduction_w_predivide(tensor)
                return

            # Accumulate destination ranks and bucket offsets for each gradient slice.
            # Note: potential future optimization, record access pattern of parameters
            # in backward pass and partition gradients w.r.t. access pattern so that our
            # bucket is guaranteed to be contiguous w.r.t. ranks
            rank_and_offsets = []
            curr_size = 0
            prev_id = -1
            for i, param, param_id in self.params_in_ipg_bucket:
                partition_ids = self.param_to_partition_ids[i][param_id]
                partition_size = self.partition_size[i]
                # Get all partition ids + their offsets
                partition_ids_w_offsets = []
                for partition_id in partition_ids:
                    offset = self.grad_start_offset[i][partition_id][param_id]
                    partition_ids_w_offsets.append((partition_id, offset))
                partition_ids_w_offsets.sort(key=lambda t: t[1])

                # Calculate rank and offsets for grad slices
                for idx in range(len(partition_ids_w_offsets)):
                    partition_id, offset = partition_ids_w_offsets[idx]

                    # Calculate numel for grad slice depending on partition location
                    if idx == len(partition_ids_w_offsets) - 1:
                        # Last partition_id uses its own offset
                        numel = param.numel() - offset
                    else:
                        # Set numel to next partition's offset
                        numel = partition_ids_w_offsets[idx + 1][1] - offset

                    # Merge bucket ranges if they belong to the same rank
                    if partition_id == prev_id:
                        prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
                        rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel)
                    else:
                        rank_and_offsets.append((partition_id, curr_size, numel))

                    curr_size += numel
                    prev_id = partition_id
            tensor.div_(dist.get_world_size(group=self.dp_process_group))

            async_handles = []
            for dst, bucket_offset, numel in rank_and_offsets:
                grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
                dst_rank = _get_global_rank(self.dp_process_group, dst)
                async_handle = dist.reduce(grad_slice,
                                           dst=dst_rank,
                                           group=self.dp_process_group,
                                           async_op=True)
                async_handles.append(async_handle)

            for handle in async_handles:
                handle.wait()
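    # Worked example (illustrative numbers only): suppose the IPG bucket holds two
    # parameters, the first with 30 elements owned entirely by partition 0, the
    # second with 50 elements of which partition 0 owns the first 20 and
    # partition 1 the remaining 30 (grad_start_offset 20). The loop above first
    # records (0, 0, 30), merges the next slice into (0, 0, 50), and then appends
    # (1, 50, 30): rank 0 receives one contiguous 50-element reduce and rank 1 a
    # 30-element reduce starting at bucket offset 50.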
    def copy_grads_in_partition(self, param):
        if self.grads_in_partition is None:
            self.grads_in_partition_offset = 0
            total_size = 0
            for group in self.params_in_partition:
                for param_in_partition in group:
                    total_size += param_in_partition.numel()

            see_memory_usage(f"before copying {total_size} gradients into partition")
            self.grads_in_partition = torch.empty(int(total_size),
                                                  dtype=torch.half,
                                                  device=torch.cuda.current_device())
            see_memory_usage(f"after copying {total_size} gradients into partition")

        #The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
        new_grad_tensor = self.grads_in_partition.narrow(0,
                                                         self.grads_in_partition_offset,
                                                         param.numel())
        new_grad_tensor.copy_(param.grad.view(-1))
        param.grad.data = new_grad_tensor.data.view_as(param.grad)
        self.grads_in_partition_offset += param.numel()

    def reduce_ipg_grads(self):
        if self.overlap_comm:
            stream = self.reduction_stream
        else:
            stream = torch.cuda.current_stream()

        if self.contiguous_gradients:
            self.average_tensor(self.ipg_buffer[self.ipg_index])
        else:
            self.buffered_reduce_fallback(
                None,
                self.grads_in_ipg_bucket,
                elements_per_buffer=self.elements_in_ipg_bucket)

        with torch.cuda.stream(stream):
            for _, param, param_id in self.params_in_ipg_bucket:
                self.params_already_reduced[param_id] = True

                if not self.is_param_in_current_partition[param_id]:
                    if self.overlap_comm and self.contiguous_gradients is False:
                        # Clear the previous grads during the next reduction
                        # to avoid clearing them before the reduction is complete.
                        if self.previous_reduced_grads is None:
                            self.previous_reduced_grads = []
                        self.previous_reduced_grads.append(param)
                    else:
                        param.grad = None
                elif self.contiguous_gradients:
                    self.copy_grads_in_partition(param)

        self.grads_in_ipg_bucket = []
        self.params_in_ipg_bucket = []
        self.elements_in_ipg_bucket = 0
        #####################################################################
    def reduce_ready_partitions_and_remove_grads(self, param, i):
        self.reduce_independent_p_g_buckets_and_remove_grads(param, i)

    def zero_reduced_gradients(self, partition_id, i):
        def are_all_related_partitions_reduced(params_id):
            for partition_id in self.param_to_partition_ids[i][params_id]:
                if not self.is_partition_reduced[i][partition_id]:
                    return False
            return True

        for params_id in self.is_grad_computed[i][partition_id]:
            if are_all_related_partitions_reduced(params_id):
                self.param_dict[params_id].grad = None

    def flatten_and_print(self, message, tensors, start=0, n=5):
        flatten_tensor = _flatten_dense_tensors(tensors)

        def print_func():
            logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))

        self.sequential_execution(print_func, message)

    def get_grads_to_reduce(self, i, partition_id):
        def get_reducable_portion(key):
            grad = self.param_dict[key].grad
            total_elements = grad.numel()
            start = self.grad_start_offset[i][partition_id][key]
            num_elements = min(
                total_elements - start,
                self.partition_size[i] -
                self.grad_partition_insertion_offset[i][partition_id][key])
            if not pg_correctness_test:
                if num_elements == total_elements:
                    return grad
                else:
                    return grad.contiguous().view(-1).narrow(0,
                                                             int(start),
                                                             int(num_elements))
            else:
                if num_elements == total_elements:
                    return grad.clone()
                else:
                    return grad.clone().contiguous().view(-1).narrow(
                        0,
                        int(start),
                        int(num_elements))

        grads_to_reduce = []
        for key in self.is_grad_computed[i][partition_id]:
            grad = get_reducable_portion(key)
            grads_to_reduce.append(grad)
        return grads_to_reduce

    def sequential_execution(self, function, message, group=None):
        if group is None:
            group = self.dp_process_group
        if dist.get_rank(group=group) == 0:
            logger.info(message)
        for id in range(dist.get_world_size(group=group)):
            if id == dist.get_rank(group=group):
                function()
            dist.barrier(group=group)

    def set_none_gradients_to_zero(self, i, partition_id):
        for param_id in self.is_grad_computed[i][partition_id]:
            param = self.param_dict[param_id]
            if param.grad is None:
                param.grad = torch.zeros_like(param)
    ######################Reduction Related Methods##############################

    def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
        rank = None
        tensor = flatten(bucket)

        tensor_to_allreduce = tensor

        if pg_correctness_test:
            allreduce_always_fp32 = True

        if allreduce_always_fp32:
            tensor_to_allreduce = tensor.float()

        tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))

        if rank is None:
            #    "All Reducing"
            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
        else:
            global_rank = _get_global_rank(self.dp_process_group, rank)
            dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)

        if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
            if rank is None or rank == dist.get_rank(group=self.dp_process_group):
                tensor.copy_(tensor_to_allreduce)

        return tensor
#if rank is specified do a reduction instead of an allreduce else:
def allreduce_and_copy(self, small_bucket, rank=None, log=None): src_tensor = param.grad.view(-1).narrow(0,
if self.overlap_comm: source_offset,
torch.cuda.synchronize() num_elements).float()
if self.previous_reduced_grads is not None: dest_tensor.copy_(src_tensor, non_blocking=True)
# previous_reduced_grads has the previous reduced grads,
# now it is safe to clear. def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
for param in self.previous_reduced_grads: param_id = self.get_param_id(param)
param.grad = None
self.previous_reduced_grads = None [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
stream = self.reduction_stream
else: dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
stream = torch.cuda.current_stream() 0,
dest_offset,
with torch.cuda.stream(stream): num_elements)
allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group): src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float()
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)): dest_tensor.copy_(src_tensor, non_blocking=True)
buf.copy_(synced) param.grad = None
def allreduce_no_retain(self, def complete_grad_norm_calculation_for_cpu_offload(self, params):
bucket, total_norm = 0.0
numel_per_bucket=500000000, norm_type = 2.0
rank=None, for p in params:
log=None): if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
small_bucket = [] param_id = self.get_param_id(p)
numel = 0 param_norm = self.norm_for_param_grads[param_id]
for tensor in bucket: total_norm += param_norm.item()**2
small_bucket.append(tensor)
numel = numel + tensor.numel() # Sum across all model parallel GPUs.
if numel > numel_per_bucket: total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = [] torch.distributed.all_reduce(total_norm_cuda,
if len(small_bucket) > 0: op=torch.distributed.ReduceOp.SUM,
self.allreduce_and_copy(small_bucket, rank=rank, log=log) group=self.dp_process_group)
#allows using reduction of gradients instead of using all_reduce self._model_parallel_all_reduce(tensor=total_norm_cuda,
def buffered_reduce_fallback(self, op=torch.distributed.ReduceOp.SUM)
rank,
grads, total_norm = total_norm_cuda[0].item()**(1. / norm_type)
elements_per_buffer=500000000,
log=None): if total_norm == float(
split_buckets = split_half_float_double(grads) 'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket, return total_norm
numel_per_bucket=elements_per_buffer,
rank=rank, ############################################################################################
log=log)
def copy_grads_in_partition(self, param):
############################################################################# if self.cpu_offload:
#############################################################################
############################################################################# if self.gradient_accumulation_steps > 1:
self.async_accumulate_grad_in_cpu_via_gpu(param)
#views the tensor as multiple partitions and returns
#those partitions if self.is_gradient_accumulation_boundary:
def get_data_parallel_partitions(self, tensor): self.set_norm_for_param_grad_in_gpu(param)
partitions = []
self.update_overflow_tracker_for_param_grad(param)
dp = dist.get_world_size(group=self.dp_process_group)
dp_id = dist.get_rank(group=self.dp_process_group) self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
total_num_elements = tensor.numel() return
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
base_size = total_num_elements // dp if self.grads_in_partition is None:
remaining = total_num_elements % dp self.grads_in_partition_offset = 0
total_size = 0
start = 0 for group in self.params_in_partition:
for id in range(dp): for param_in_partition in group:
partition_size = base_size total_size += param_in_partition.numel()
if id < remaining:
partition_size = partition_size + 1 see_memory_usage(f"before copying {total_size} gradients into partition")
partitions.append(tensor.narrow(0, start, partition_size)) self.grads_in_partition = torch.empty(int(total_size),
start = start + partition_size dtype=torch.half,
return partitions device=torch.cuda.current_device())
see_memory_usage(f"after copying {total_size} gradients into partition")
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = [] #The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
params_not_in_partition = [] new_grad_tensor = self.grads_in_partition.view(-1).narrow(
0,
start_index = partition_size * partition_id self.grads_in_partition_offset,
end_index = partition_size * (partition_id + 1) param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
current_index = 0 param.grad.data = new_grad_tensor.data.view_as(param.grad)
first_offset = 0 #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
self.grads_in_partition_offset += param.numel()
for tensor in tensor_list:
def reduce_ipg_grads(self):
tensor_size = tensor.numel() if self.overlap_comm:
stream = self.reduction_stream
if (current_index >= start_index and current_index < end_index): else:
params_in_partition.append(tensor) stream = torch.cuda.current_stream()
elif start_index > current_index and start_index < (current_index + if self.contiguous_gradients:
tensor_size): self.average_tensor(self.ipg_buffer[self.ipg_index])
params_in_partition.append(tensor) else:
self.buffered_reduce_fallback(
assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition" None,
first_offset = start_index - current_index self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
else:
params_not_in_partition.append(tensor) with torch.cuda.stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
current_index = current_index + tensor_size self.params_already_reduced[param_id] = True
return params_in_partition, params_not_in_partition, first_offset if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
def zero_grad(self, set_grads_to_None=True): # Clear the previous grads during the next reduction
""" # to avoid clearing them before the reduction is complete.
Zero FP16 parameter grads. if self.previous_reduced_grads is None:
""" self.previous_reduced_grads = []
# FP32 grad should never exist. self.previous_reduced_grads.append(param)
# For speed, set model fp16 grad to None by default else:
for group in self.fp16_groups: param.grad = None
for p in group: elif self.contiguous_gradients:
if set_grads_to_None: self.copy_grads_in_partition(param)
p.grad = None
else: self.grads_in_ipg_bucket = []
if p.grad is not None: self.params_in_ipg_bucket = []
p.grad.detach_() self.elements_in_ipg_bucket = 0
p.grad.zero_() #####################################################################
def _model_parallel_all_reduce(self, tensor, op): def reduce_ready_partitions_and_remove_grads(self, param, i):
""" Perform all reduce within model parallel group, if any. self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
"""
if self.model_parallel_group is None: def zero_reduced_gradients(self, partition_id, i):
torch.distributed.all_reduce(tensor=tensor, op=op) def are_all_related_partitions_reduced(params_id):
else: for partition_id in self.param_to_partition_ids[i][params_id]:
torch.distributed.all_reduce(tensor=tensor, if not self.is_partition_reduced[i][partition_id]:
op=op, return False
group=self.model_parallel_group) return True
def get_grad_norm_direct(self, gradients, params, norm_type=2): for params_id in self.is_grad_computed[i][partition_id]:
"""Clips gradient norm of an iterable of parameters. if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that def flatten_and_print(self, message, tensors, start=0, n=5):
the gradients are modified in place. flatten_tensor = _flatten_dense_tensors(tensors)
Arguments: def print_func():
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients self.sequential_execution(print_func, message)
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm. def get_grads_to_reduce(self, i, partition_id):
def get_reducable_portion(key):
Returns: grad = self.param_dict[key].grad
Total norm of the parameters (viewed as a single vector). total_elements = grad.numel()
""" start = self.grad_start_offset[i][partition_id][key]
norm_type = float(norm_type) num_elements = min(
if norm_type == inf: total_elements - start,
total_norm = max(g.data.abs().max() for g in gradients) self.partition_size[i] -
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) self.grad_partition_insertion_offset[i][partition_id][key])
torch.distributed.all_reduce(total_norm_cuda, if not pg_correctness_test:
op=torch.distributed.ReduceOp.MAX, if num_elements == total_elements:
group=self.dp_process_group) return grad
else:
# Take max across all GPUs. return grad.contiguous().view(-1).narrow(0,
self._model_parallel_all_reduce(tensor=total_norm_cuda, int(start),
op=torch.distributed.ReduceOp.MAX) int(num_elements))
total_norm = total_norm_cuda[0].item() else:
else: if num_elements == total_elements:
total_norm = 0.0 return grad.clone()
#if dist.get_rank() == 0: else:
# logger.info(f"Total Norm begining {total_norm}") return grad.clone().contiguous().view(-1).narrow(
for g, p in zip(gradients, params): 0,
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): int(start),
param_norm = g.data.double().norm(2) int(num_elements))
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs. grads_to_reduce = []
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) for key in self.is_grad_computed[i][partition_id]:
grad = get_reducable_portion(key)
torch.distributed.all_reduce(total_norm_cuda, grads_to_reduce.append(grad)
op=torch.distributed.ReduceOp.SUM, return grads_to_reduce
group=self.dp_process_group)
def sequential_execution(self, function, message, group=None):
self._model_parallel_all_reduce(tensor=total_norm_cuda, if group is None:
op=torch.distributed.ReduceOp.SUM) group = self.dp_process_group
if dist.get_rank(group=group) == 0:
total_norm = total_norm_cuda[0].item()**(1. / norm_type) logger.info(message)
for id in range(dist.get_world_size(group=group)):
if total_norm == float( if id == dist.get_rank(group=group):
'inf') or total_norm == -float('inf') or total_norm != total_norm: function()
total_norm = -1 dist.barrier(group=group)
return total_norm def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
#creates a flat fused tensor from the tensor list starting at the first_offset param = self.param_dict[param_id]
#in the first tensor of the list. If there are not enough elements in the tensor if param.grad is None:
#list then the flat tensor will be padded with zeros param.grad = torch.zero_like(param)
def get_flat_partition(self,
tensor_list, ######################Reduction Related Methods##############################
first_offset,
partition_size, def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
dtype, rank = None
device, tensor = flatten(bucket)
return_tensor_list=False):
flat_tensor_list = [] tensor_to_allreduce = tensor
current_size = 0
for i, tensor in enumerate(tensor_list): if pg_correctness_test:
if tensor.grad is None: allreduce_always_fp32 = True
continue
if allreduce_always_fp32:
tensor = tensor.grad tensor_to_allreduce = tensor.float()
num_elements = tensor.numel()
tensor_offset = 0 tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
#we need to offset to get to the right element if rank is None:
if i == 0 and first_offset > 0: # "All Reducing"
tensor_offset = first_offset dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
num_elements = num_elements - tensor_offset else:
global_rank = _get_global_rank(self.dp_process_group, rank)
#we dont need all elements of the tensor dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
#we need a narrow view of the tensor based on the tensor offset and number of elements that tensor.copy_(tensor_to_allreduce)
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel(): return tensor
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0, #if rank is specified do a reduction instead of an allreduce
int(tensor_offset), def allreduce_and_copy(self, small_bucket, rank=None, log=None):
int(num_elements))) if self.overlap_comm:
else: torch.cuda.synchronize()
flat_tensor_list.append(tensor) if self.previous_reduced_grads is not None:
# previous_reduced_grads has the previous reduced grads,
current_size = current_size + num_elements # now it is safe to clear.
for param in self.previous_reduced_grads:
#this means its the last partition and does not align with the dp boundary. We need to pad before flattening param.grad = None
if current_size < partition_size: self.previous_reduced_grads = None
flat_tensor_list.append( stream = self.reduction_stream
torch.zeros(int(partition_size - current_size), else:
dtype=dtype, stream = torch.cuda.current_stream()
device=device))
with torch.cuda.stream(stream):
if return_tensor_list: allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
return flat_tensor_list if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
return _flatten_dense_tensors(flat_tensor_list) buf.copy_(synced)
def free_grad_in_param_list(self, param_list): def allreduce_no_retain(self,
for p in param_list: bucket,
p.grad = None numel_per_bucket=500000000,
rank=None,
def step(self, closure=None): log=None):
""" small_bucket = []
Not supporting closure. numel = 0
""" for tensor in bucket:
see_memory_usage(f"In step before checking overflow") small_bucket.append(tensor)
numel = numel + tensor.numel()
# First compute norm for all group so we know if there is overflow if numel > numel_per_bucket:
self.check_overflow() self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
timers = self.timers if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
prev_scale = self.loss_scale
self._update_scale(self.overflow) #allows using reduction of gradients instead of using all_reduce
if self.overflow: def buffered_reduce_fallback(self,
see_memory_usage('After overflow before clearing gradients') rank,
self.zero_grad() grads,
for key in self.averaged_gradients: elements_per_buffer=500000000,
self.averaged_gradients[key] = None log=None):
split_buckets = split_half_float_double(grads)
see_memory_usage('After overflow after clearing gradients')
for i, bucket in enumerate(split_buckets):
logger.info( self.allreduce_no_retain(bucket,
"[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " numel_per_bucket=elements_per_buffer,
"reducing to {}".format(dist.get_rank(), rank=rank,
prev_scale, log=log)
self.loss_scale))
timers('optimizer_step').start() #############################################################################
timers('optimizer_step').stop() #############################################################################
timers('optimizer_allgather').start() #############################################################################
timers('optimizer_allgather').stop()
return #views the tensor as multiple partitions and returns
#those partitions
norm_groups = [] def get_data_parallel_partitions(self, tensor):
single_partition_grad_groups = [] partitions = []
skip = False
partition_id = dist.get_rank(group=self.dp_process_group) dp = dist.get_world_size(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups): dp_id = dist.get_rank(group=self.dp_process_group)
norm_groups.append( total_num_elements = tensor.numel()
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i])) base_size = total_num_elements // dp
remaining = total_num_elements % dp
#free gradients for all the prameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_in_partition[i]) start = 0
for id in range(dp):
#create a flat gradients for parameters updated by this process partition_size = base_size
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors if id < remaining:
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: partition_size = partition_size + 1
single_grad_partition = flatten_dense_tensors_aligned( partitions.append(tensor.narrow(0, start, partition_size))
self.averaged_gradients[i], start = start + partition_size
int(self.partition_size[i])).to( return partitions
self.single_partition_of_fp32_groups[i].dtype)
else: def get_partition_info(self, tensor_list, partition_size, partition_id):
single_grad_partition = _flatten_dense_tensors( params_in_partition = []
self.averaged_gradients[i]).to( params_not_in_partition = []
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \ start_index = partition_size * partition_id
"averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id) end_index = partition_size * (partition_id + 1)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition current_index = 0
#release all the gradient since we have already created a necessary copy in dp_grad_partition first_offset = 0
self.free_grad_in_param_list(self.params_in_partition[i])
for tensor in tensor_list:
self.averaged_gradients[i] = None
tensor_size = tensor.numel()
single_partition_grad_groups.append(single_grad_partition)
if (current_index >= start_index and current_index < end_index):
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups) params_in_partition.append(tensor)
timers('optimizer_step').start() elif start_index > current_index and start_index < (current_index +
self.optimizer.step() tensor_size):
#get rid of the fp32 gradients. Not needed anymore params_in_partition.append(tensor)
for group in self.single_partition_of_fp32_groups:
group.grad = None assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data) else:
timers('optimizer_step').stop() params_not_in_partition.append(tensor)
timers('optimizer_allgather').start() current_index = current_index + tensor_size
#gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups): return params_in_partition, params_not_in_partition, first_offset
#Sequential AllGather Best of both worlds def zero_grad(self, set_grads_to_None=True):
dp_world_size = dist.get_world_size(group=self.dp_process_group) """
num_shards = max( Zero FP16 parameter grads.
1, """
partitioned_params[partition_id].numel() * dp_world_size // # FP32 grad should never exist.
self.allgather_bucket_size) # For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
shard_size = partitioned_params[partition_id].numel() // num_shards for p in group:
num_elements = shard_size if set_grads_to_None:
p.grad = None
assert shard_size * num_shards <= partitioned_params[partition_id].numel() else:
if p.grad is not None:
for shard_id in range(num_shards): p.grad.detach_()
p.grad.zero_()
if shard_id == (num_shards - 1):
num_elements = partitioned_params[partition_id].numel( def _model_parallel_all_reduce(self, tensor, op):
) - shard_id * shard_size """ Perform all reduce within model parallel group, if any.
"""
shard_list = [] if self.model_parallel_group is None:
for dp_id in range(dp_world_size): torch.distributed.all_reduce(tensor=tensor, op=op)
curr_shard = partitioned_params[dp_id].narrow( else:
0, torch.distributed.all_reduce(tensor=tensor,
shard_id * shard_size, op=op,
num_elements).detach() group=self.model_parallel_group)
shard_list.append(curr_shard)
def get_grad_norm_direct(self, gradients, params, norm_type=2):
dist.all_gather(shard_list, """Clips gradient norm of an iterable of parameters.
shard_list[partition_id],
group=self.dp_process_group) This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
timers('optimizer_allgather').stop() added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)): Arguments:
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
self.fp16_groups[i]) single Tensor that will have gradients normalized
for p, q in zip(self.fp16_groups[i], updated_params): max_norm (float or int): max norm of the gradients
p.data = q.data norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
see_memory_usage('After zero_optimizer step')
return Returns:
Total norm of the parameters (viewed as a single vector).
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups): """
total_norm = 0.0 norm_type = float(norm_type)
for norm in norm_groups: if norm_type == inf:
total_norm += norm**2.0 total_norm = max(g.data.abs().max() for g in gradients)
total_norm = math.sqrt(total_norm) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
# compute combined scale factor for this group op=torch.distributed.ReduceOp.MAX,
combined_scale = self.loss_scale group=self.dp_process_group)
if self.clip_grad > 0.:
# norm is in fact norm*scale # Take max across all GPUs.
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad self._model_parallel_all_reduce(tensor=total_norm_cuda,
if clip > 1: op=torch.distributed.ReduceOp.MAX)
combined_scale = clip * self.loss_scale total_norm = total_norm_cuda[0].item()
else:
for grad in grad_groups_flat: total_norm = 0.0
if isinstance(grad, list): #if dist.get_rank() == 0:
sub_partitions = grad # logger.info(f"Total Norm begining {total_norm}")
for g in sub_partitions: for g, p in zip(gradients, params):
g.data.mul_(1. / combined_scale) if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
else: param_norm = g.data.double().norm(2)
grad.data.mul_(1. / combined_scale) total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
def _check_overflow(self, partition_gradients=True): total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
self.overflow = self.has_overflow(partition_gradients)
torch.distributed.all_reduce(total_norm_cuda,
# `params` is a list / generator of torch.Variable op=torch.distributed.ReduceOp.SUM,
def has_overflow_serial(self, params, is_grad_list=False): group=self.dp_process_group)
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data): self._model_parallel_all_reduce(tensor=total_norm_cuda,
return True op=torch.distributed.ReduceOp.SUM)
return False total_norm = total_norm_cuda[0].item()**(1. / norm_type)
def has_overflow_partitioned_grads_serial(self): if total_norm == float(
for i in range(len(self.fp16_groups)): 'inf') or total_norm == -float('inf') or total_norm != total_norm:
for j, grad in enumerate(self.averaged_gradients[i]): total_norm = -1
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True return total_norm
return False
#creates a flat fused tensor from the tensor list starting at the first_offset
def has_overflow(self, partition_gradients=True): #in the first tensor of the list. If there are not enough elements in the tensor
if partition_gradients: #list then the flat tensor will be padded with zeros
overflow = self.has_overflow_partitioned_grads_serial() def get_flat_partition(self,
overflow_gpu = torch.cuda.ByteTensor([overflow]) tensor_list,
torch.distributed.all_reduce(overflow_gpu, first_offset,
op=torch.distributed.ReduceOp.MAX, partition_size,
group=self.dp_process_group) dtype,
device,
else: return_tensor_list=False):
params = [] flat_tensor_list = []
for group in self.fp16_groups: current_size = 0
for param in group: for i, tensor in enumerate(tensor_list):
params.append(param) if tensor.grad is None:
continue
overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
overflow_gpu = torch.cuda.ByteTensor([overflow]) tensor = tensor.grad
num_elements = tensor.numel()
# Since each model parallel GPU carries only part of the model, tensor_offset = 0
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu, #we need to offset to get to the right element
op=torch.distributed.ReduceOp.MAX) if i == 0 and first_offset > 0:
tensor_offset = first_offset
overflow = overflow_gpu[0].item() num_elements = num_elements - tensor_offset
return bool(overflow)
#we dont need all elements of the tensor
# `x` is a torch.Tensor if num_elements > (partition_size - current_size):
@staticmethod num_elements = partition_size - current_size
def _has_inf_or_nan(x, j=None):
try: #we need a narrow view of the tensor based on the tensor offset and number of elements that
# if x is half, the .float() incurs an additional deep copy, but it's necessary if #we need from this tensor
# Pytorch's .sum() creates a one-element tensor of the same type as x if tensor_offset > 0 or num_elements < tensor.numel():
# (which is true for some recent version of pytorch). flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
cpu_sum = float(x.float().sum()) 0,
# More efficient version that can be used if .sum() returns a Python scalar int(tensor_offset),
# cpu_sum = float(x.sum()) int(num_elements)))
except RuntimeError as instance: else:
# We want to check if inst is actually an overflow exception. flat_tensor_list.append(tensor)
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate. current_size = current_size + num_elements
if "value cannot be converted" not in instance.args[0]:
raise #this means its the last partition and does not align with the dp boundary. We need to pad before flattening
return True if current_size < partition_size:
else: flat_tensor_list.append(
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: torch.zeros(int(partition_size - current_size),
return True dtype=dtype,
return False device=device))
def backward(self, loss, retain_graph=False): if return_tensor_list:
""" return flat_tensor_list
:attr:`backward` performs the following steps:
return _flatten_dense_tensors(flat_tensor_list)
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale def free_grad_in_param_list(self, param_list):
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves for p in param_list:
""" p.grad = None
if self.contiguous_gradients:
self.ipg_buffer = [] def reset_cpu_buffers(self):
buf_0 = torch.empty(self.reduce_bucket_size, self.norm_for_param_grads = {}
dtype=torch.half, self.local_overflow = False
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0) def step(self, closure=None):
"""
# Use double buffers to avoid data access conflict when overlap_comm is enabled. Not supporting closure.
if self.overlap_comm: """
buf_1 = torch.empty(self.reduce_bucket_size, self.micro_step_id = -1
dtype=torch.half,
device=torch.cuda.current_device()) if self.cpu_offload:
self.ipg_buffer.append(buf_1) torch.cuda.current_stream().wait_stream(self.migration_stream)
self.ipg_index = 0
see_memory_usage(f"In step before checking overflow")
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
# First compute norm for all group so we know if there is overflow
def check_overflow(self, partition_gradients=True): self.check_overflow()
self._check_overflow(partition_gradients)
timers = self.timers
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow) prev_scale = self.loss_scale
self._update_scale(self.overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" if self.overflow:
def _get_state(self): see_memory_usage('After overflow before clearing gradients')
return self.optimizer.state self.zero_grad()
if self.cpu_offload:
def _set_state(self, value): self.reset_cpu_buffers()
self.optimizer.state = value else:
self.averaged_gradients = {}
state = property(_get_state, _set_state)
see_memory_usage('After overflow after clearing gradients')
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate) logger.info(
def _get_param_groups(self): "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
return self.optimizer.param_groups "reducing to {}".format(dist.get_rank(),
prev_scale,
def _set_param_groups(self, value): self.loss_scale))
self.optimizer.param_groups = value timers('optimizer_gradients').start()
timers('optimizer_gradients').stop()
param_groups = property(_get_param_groups, _set_param_groups) timers('optimizer_step').start()
timers('optimizer_step').stop()
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" timers('optimizer_allgather').start()
def _get_loss_scale(self): timers('optimizer_allgather').stop()
return self.loss_scaler.loss_scale return
def _set_loss_scale(self, value): timers('optimizer_gradients').start()
self.loss_scaler.cur_scale = value norm_groups = []
single_partition_grad_groups = []
loss_scale = property(_get_loss_scale, _set_loss_scale) skip = False
cur_scale = property(_get_loss_scale, _set_loss_scale) partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
# Return group tensor after removing paddings that are added for alignment to DP world size. if self.cpu_offload:
# This method works on the assumption that each group contains a single flattened tensor. norm_groups.append(
def _get_groups_without_padding(self, groups_with_padding): self.complete_grad_norm_calculation_for_cpu_offload(
groups_without_padding = [] self.params_in_partition[i]))
for i, group in enumerate(groups_with_padding): single_grad_partition = self.single_partition_of_fp32_groups[i].grad
lean_length = group.numel() - self.groups_padding[i] else:
groups_without_padding.append(group[:lean_length]) norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
return groups_without_padding self.params_in_partition[i]))
# Return optimizer state after removing paddings that are added for alignment. #free gradients for all the prameters that are not updated by this process
def _get_state_without_padding(self, state_with_padding, padding): self.free_grad_in_param_list(self.params_not_in_partition[i])
lean_state = {}
for key, value in state_with_padding.items(): #create a flat gradients for parameters updated by this process
lean_length = value.numel() - padding # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
lean_state[key] = value[:lean_length] if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = flatten_dense_tensors_aligned(
return lean_state self.averaged_gradients[i],
int(self.partition_size[i])).to(
# Return base optimizer states. self.single_partition_of_fp32_groups[i].dtype)
# This method assumes that each param group contains a single flattened tensor. else:
def _get_base_optimizer_state(self): single_grad_partition = _flatten_dense_tensors(
optimizer_groups_state = [] self.averaged_gradients[i]).to(
for i, group in enumerate(self.optimizer.param_groups): self.single_partition_of_fp32_groups[i].dtype)
p = group['params'][0] assert single_grad_partition.numel() == self.partition_size[i], \
lean_optimizer_state = self._get_state_without_padding( "averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.optimizer.state[p],
self.groups_padding[i]) self.single_partition_of_fp32_groups[i].grad = single_grad_partition
optimizer_groups_state.append(lean_optimizer_state) #release all the gradient since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(self.params_in_partition[i])
return optimizer_groups_state
self.averaged_gradients[i] = None
def state_dict(self):
""" single_partition_grad_groups.append(single_grad_partition)
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
of the contained Pytorch optimizer. timers('optimizer_gradients').stop()
Example::
checkpoint = {} #torch.set_num_threads(12)
checkpoint['model'] = model.state_dict() timers('optimizer_step').start()
checkpoint['optimizer'] = optimizer.state_dict() if self.deepspeed_adam_offload:
torch.save(checkpoint, "saved.pth") self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
""" #self.optimizer.step()
state_dict = {} #for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
state_dict['loss_scaler'] = self.loss_scaler # fp16_partitions[partition_id].data.copy_(fp32_partition.data)
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale else:
state_dict['overflow'] = self.overflow self.optimizer.step()
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
#get rid of the fp32 gradients. Not needed anymore
state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS if not self.cpu_offload:
state_dict['partition_count'] = self.partition_count for group in self.single_partition_of_fp32_groups:
group.grad = None
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding( for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
self.single_partition_of_fp32_groups) fp16_partitions[partition_id].data.copy_(fp32_partition.data)
state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
timers('optimizer_step').stop()
return state_dict
if self.cpu_offload:
# Restore base optimizer fp32 weights from checkpoint by: self.reset_cpu_buffers()
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights timers('optimizer_allgather').start()
# 3) Using extracted weights to update base optimizer weights directly. #gather the updated weights from everyone
def _restore_from_fp32_weights(self, all_state_dict): for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
partition_id = dist.get_rank(group=self.dp_process_group)
merged_single_partition_of_fp32_groups = [] #Sequential AllGather Best of both worlds
for i in range(len(self.single_partition_of_fp32_groups)): dp_world_size = dist.get_world_size(group=self.dp_process_group)
merged_partitions = [ num_shards = max(
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict 1,
] partitioned_params[partition_id].numel() * dp_world_size //
flat_merged_partitions = flatten_dense_tensors_aligned( self.allgather_bucket_size)
merged_partitions,
dist.get_world_size(group=self.dp_process_group)) shard_size = partitioned_params[partition_id].numel() // num_shards
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) num_elements = shard_size
merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
assert shard_size * num_shards <= partitioned_params[partition_id].numel()
for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
current.data.copy_(saved.data) for shard_id in range(num_shards):
# Restore base optimizer fp32 weights from ZeRO fp16 weights if shard_id == (num_shards - 1):
def _restore_from_fp16_weights(self): num_elements = partitioned_params[partition_id].numel(
partition_id = dist.get_rank(group=self.dp_process_group) ) - shard_id * shard_size
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp32_partition.data.copy_(fp16_partitions[partition_id].data) shard_list = []
for dp_id in range(dp_world_size):
# Refresh the fp32 master params from the fp16 copies. curr_shard = partitioned_params[dp_id].narrow(
def refresh_fp32_params(self): 0,
self._restore_from_fp16_weights() shard_id * shard_size,
num_elements).detach()
# Extract optimizer state for current partition from merged states of all partitions shard_list.append(curr_shard)
def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group) dist.all_gather(shard_list,
alignment = dist.get_world_size(group=self.dp_process_group) shard_list[partition_id],
flat_merged_partitions = flatten_dense_tensors_aligned( group=self.dp_process_group)
all_partition_states, timers('optimizer_allgather').stop()
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) # TODO: we probably don't need this? just to be safe
return dp_partitions[partition_id] for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
# Restore base optimizer state from checkpoint by self.fp16_groups[i])
# 1) Merging optimizer state from checkpoints of all partitions for p, q in zip(self.fp16_groups[i], updated_params):
# 2) Extracting optimizer state for current partition from the merged state p.data = q.data
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict): timers.log(
base_optimizer_group_states = [] names=['optimizer_gradients',
for i in range(len(self.optimizer.param_groups)): 'optimizer_step',
partition_states = {} 'optimizer_allgather'])
all_partition_group_states = [ see_memory_usage('After zero_optimizer step')
sd['base_optimizer_state'][i] for sd in all_state_dict return
]
for key in all_partition_group_states[0].keys(): def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
all_partition_states = [ total_norm = 0.0
all_states[key] for all_states in all_partition_group_states for norm in norm_groups:
] total_norm += norm**2.0
partition_states[key] = self._partition_base_optimizer_state( total_norm = math.sqrt(total_norm)
key,
all_partition_states) # compute combined scale factor for this group
base_optimizer_group_states.append(partition_states) combined_scale = self.loss_scale
if self.clip_grad > 0.:
for i, group in enumerate(self.optimizer.param_groups): # norm is in fact norm*scale
p = group['params'][0] clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
for key, saved in base_optimizer_group_states[i].items(): if clip > 1:
current = self.optimizer.state[p][key] combined_scale = clip * self.loss_scale
current.data.copy_(saved.data)
for grad in grad_groups_flat:
def load_state_dict(self, if isinstance(grad, list):
state_dict_list, sub_partitions = grad
load_optimizer_states=True, for g in sub_partitions:
load_from_fp32_weights=False): g.data.mul_(1. / combined_scale)
r"""Loading ZeRO checkpoint else:
grad.data.mul_(1. / combined_scale)
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. def _check_overflow(self, partition_gradients=True):
Note that the number of saved partitions may differ from number of loading partitions to support self.overflow = self.has_overflow(partition_gradients)
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states # `params` is a list / generator of torch.Variable
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 def has_overflow_serial(self, params, is_grad_list=False):
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). for p in params:
""" if p.grad is not None and self._has_inf_or_nan(p.grad.data):
""" return True
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, return False
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before def has_overflow_partitioned_grads_serial(self):
``fp16_optimizer_instance.load_state_dict()`` is called. for i in range(len(self.fp16_groups)):
Example:: for j, grad in enumerate(self.averaged_gradients[i]):
model = torch.nn.Linear(D_in, D_out).cuda().half() if grad is not None and self._has_inf_or_nan(grad.data, j):
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) return True
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) return False
...
checkpoint = torch.load("saved.pth") def has_overflow(self, partition_gradients=True):
model.load_state_dict(checkpoint['model']) if partition_gradients:
optimizer.load_state_dict(checkpoint['optimizer']) overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial(
""" )
# I think it should actually be ok to reload the optimizer before the model. overflow_gpu = torch.cuda.ByteTensor([overflow])
self.loss_scaler = state_dict_list[0]['loss_scaler'] torch.distributed.all_reduce(overflow_gpu,
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale'] op=torch.distributed.ReduceOp.MAX,
self.overflow = state_dict_list[0]['overflow'] group=self.dp_process_group)
if load_optimizer_states: else:
self._restore_base_optimizer_state(state_dict_list) params = []
for group in self.fp16_groups:
# At this point, the optimizer's references to the model's fp32 parameters are up to date. for param in group:
# The optimizer's hyperparameters and internal buffers are also up to date. params.append(param)
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options. overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
# 1: Refresh the master params from the model's fp16 params. overflow_gpu = torch.cuda.ByteTensor([overflow])
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately. # Since each model parallel GPU carries only part of the model,
# We choose option 1 if changing DP degree and option 2 otherwise. # make sure overflow flag is synced across all the model parallel GPUs
# self._model_parallel_all_reduce(tensor=overflow_gpu,
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device op=torch.distributed.ReduceOp.MAX)
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been overflow = overflow_gpu[0].item()
# constructed in the same way as the one whose state_dict we are loading, the same master params return bool(overflow)
# are guaranteed to exist, so we can just copy_() from the saved master params.
# `x` is a torch.Tensor
if load_from_fp32_weights: @staticmethod
self._restore_from_fp32_weights(state_dict_list) def _has_inf_or_nan(x, j=None):
else: try:
self._restore_from_fp16_weights() # if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
def _handle_overflow(cpu_sum, x, i): cpu_sum = float(x.float().sum())
import math # More efficient version that can be used if .sum() returns a Python scalar
rank = torch.distributed.get_rank() # cpu_sum = float(x.sum())
if rank == 0: except RuntimeError as instance:
t_i = -1 # We want to check if inst is actually an overflow exception.
for v_i, v in enumerate(x.data.contiguous().view(-1)): # RuntimeError could come from a different error.
if not math.isfinite(float(v)): # If so, we still want the exception to propagate.
t_i = v_i if "value cannot be converted" not in instance.args[0]:
break raise
logger.info( return True
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" else:
) if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
if self.cpu_offload:
torch.cuda.current_stream().wait_stream(self.migration_stream)
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(int(self.reduce_bucket_size * 4.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(int(self.reduce_bucket_size * 4.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
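# What the loss scaler call below amounts to, per the steps listed in the
# docstring above (a sketch, not an additional code path):
#   scaled_loss = loss.float() * self.loss_scale
#   scaled_loss.backward(retain_graph=retain_graph)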
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains a single flattened tensor.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_length = group.numel() - self.groups_padding[i]
groups_without_padding.append(group[:lean_length])
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
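# Tensor states (e.g. exp_avg / exp_avg_sq) carry the DP-alignment padding and are
# trimmed; scalar states (e.g. the Adam step count) are returned unchanged.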
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_optimizer_state = self._get_state_without_padding(
self.optimizer.state[p],
self.groups_padding[i])
optimizer_groups_state.append(lean_optimizer_state)
return optimizer_groups_state
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
state_dict['partition_count'] = self.partition_count
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
self.single_partition_of_fp32_groups)
state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
# if self.cpu_offload:
# state_dict_tmp = async_copy_to(state_dict,
# 'cpu',
# torch.cuda.current_stream())
# state_dict = state_dict_tmp
return state_dict
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
partition_id = dist.get_rank(group=self.dp_process_group)
merged_single_partition_of_fp32_groups = []
for i in range(len(self.single_partition_of_fp32_groups)):
merged_partitions = [
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
]
flat_merged_partitions = flatten_dense_tensors_aligned(
merged_partitions,
dist.get_world_size(group=self.dp_process_group))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp32_partition.data.copy_(fp16_partitions[partition_id].data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
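# Tensor states from all saved partitions are flattened with padding to the current
# DP world size and then sliced down to this rank's partition; non-tensor states
# are handled in the else branch below.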
if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
return dp_partitions[partition_id]
else:
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return all_partition_states[0]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [
sd['base_optimizer_state'][i] for sd in all_state_dict
]
for key in all_partition_group_states[0].keys():
all_partition_states = [
all_states[key] for all_states in all_partition_group_states
]
partition_states[key] = self._partition_base_optimizer_state(
key,
all_partition_states)
base_optimizer_group_states.append(partition_states)
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
self.optimizer.state[p][key].data.copy_(saved.data)
else:
self.optimizer.state[p][key] = saved
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading ZeRO checkpoint
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
"""
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict_list[0]['loss_scaler']
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
self.overflow = state_dict_list[0]['overflow']
if load_optimizer_states:
self._restore_base_optimizer_state(state_dict_list)
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 1 if changing DP degree and option 2 otherwise.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
if load_from_fp32_weights:
self._restore_from_fp32_weights(state_dict_list)
else:
self._restore_from_fp16_weights()
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
logger.info(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
import torch
import torch.distributed as dist
import apex
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam

def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
...@@ -20,3 +21,17 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
        if rank in ranks:
            my_group = group
    return my_group
ZERO_SUPPORTED_OPTIMIZERS = [
    torch.optim.Adam,
    apex.optimizers.FusedAdam,
    DeepSpeedCPUAdam
]


def is_zero_supported_optimizer(optimizer):
    print(
        f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
    )
    return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
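A quick, hedged usage example of the check above; the model and learning rate are placeholders:

# Illustrative guard before enabling ZeRO with a base optimizer.
model = torch.nn.Linear(128, 128)
basic_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
assert is_zero_supported_optimizer(basic_optimizer), \
    "ZeRO currently supports torch.optim.Adam, apex FusedAdam, and DeepSpeedCPUAdam"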
...@@ -162,6 +162,14 @@ Please see the [core API doc](https://deepspeed.readthedocs.io/) for more detail
With DeepSpeed, the user can choose to use a high performance implementation of ADAM from
NVIDIA, or any training optimizer that extends torch's `torch.optim.Optimizer` class.
### CPU-Adam: High-Performance vectorized implementation of Adam

We introduce an efficient CPU implementation of the Adam optimizer that improves parameter-update
performance by nearly an order of magnitude. CPU-Adam uses the AVX SIMD instructions on Intel x86
architectures and supports both the AVX-512 and AVX-2 instruction sets. DeepSpeed uses AVX-2 by
default; this can be switched to AVX-512 by setting the build flag `DS_BUILD_AVX512` to 1 when
installing DeepSpeed. With AVX-512, we observe 5.1x to 6.5x speedups over torch-adam for model
sizes between 1 and 10 billion parameters.
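As a minimal sketch of using CPU-Adam directly (the constructor arguments shown are assumed to mirror `torch.optim.Adam`; other hyperparameters keep their defaults, and the CPU-Adam extension must be built or JIT-loadable):

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# CPU-Adam updates parameters that live on the CPU.
model = torch.nn.Linear(4096, 4096)
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

loss = model(torch.randn(8, 4096)).sum()
loss.backward()
optimizer.step()
```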
### Memory bandwidth optimized FP16 Optimizer

Mixed precision training is handled by the DeepSpeed FP16 Optimizer. This optimizer not
only handles FP16 training but is also highly efficient. The performance of weight update
...
...@@ -167,6 +167,7 @@ overview](/features/) for descriptions and usage.
  * Automatic loss scaling with mixed precision
* [Training Optimizers](/features/#training-optimizers)
  * Fused Adam optimizer and arbitrary `torch.optim.Optimizer`
  * CPU-Adam: High-Performance vectorized Adam
  * Memory bandwidth optimized FP16 Optimizer
  * Large Batch Training with LAMB Optimizer
  * Memory efficient Training with ZeRO Optimizer
...
...@@ -164,10 +164,10 @@ if [ ! -f $hostfile ]; then
    local_only=1
fi

if [ "$skip_requirements" == "0" ]; then
    # Ensure dependencies are installed locally
    $PIP_SUDO $PIP_INSTALL -r requirements/requirements.txt
fi

# Build wheels
if [ "$third_party_install" == "1" ]; then
...@@ -220,10 +220,10 @@ else
    tmp_wheel_path="/tmp/deepspeed_wheels"
    pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; else mkdir -pv $tmp_wheel_path; fi"
    pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/
    if [ "$skip_requirements" == "0" ]; then
        pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
    fi

    if [ "$third_party_install" == "1" ]; then
        pdsh -w $hosts "$PIP_SUDO pip uninstall -y apex"
        pdcp -w $hosts third_party/apex/dist/apex*.whl $tmp_wheel_path/
...