Unverified Commit 41db1c2f authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files
parent 79093d74
#include "cpu_adam.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
// C++ interface
void Adam_Optimizer::Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
float step_size = -1 * _alpha / bias_correction1;
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay4;
if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
grad_4.data = SIMD_LOAD(grads + i);
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
AVX_Data param_4;
param_4.data = SIMD_LOAD(_params + i);
if (_weight_decay > 0)
grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
SIMD_STORE(_params + i, param_4.data);
if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size) {
#pragma omp parallel for
for (size_t k = rounded_size; k < _param_size; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0) grad = param * _weight_decay + grad;
momentum *= momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * bias_correction2 + _eps;
grad = momentum / grad;
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + rounded_size,
(_param_size - rounded_size),
Context::Instance().GetCurrentStream());
}
}
}
void Adam_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
AVX_Data grad_4[4];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
AVX_Data momentum_4[4];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
AVX_Data variance_4[4];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
AVX_Data param_4[4];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
if (_weight_decay > 0) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int create_adam_optimizer(int optimizer_id,
float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
{
auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay);
s_optimizers[optimizer_id] = opt;
#if defined(__AVX512__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX512 arithmetic capability." << std::endl;
#else
#if defined(__AVX256__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX2 arithmetic capability." << std::endl;
#else
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with scalar arithmetic capability." << std::endl;
#endif
#endif
return 0;
}
void Adam_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
AVX_Data grad_4[8];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2));
grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5);
grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6);
grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7);
AVX_Data momentum_4[8];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2));
momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5);
momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6);
momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7);
AVX_Data variance_4[8];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2));
variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5);
variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6);
variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7);
AVX_Data param_4[8];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2));
param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5);
param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);
if (_weight_decay > 0) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data);
grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data);
grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data);
grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data);
momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data);
momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data);
momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data);
momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data);
momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data);
momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data);
momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data);
variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data);
variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data);
variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data);
variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data);
variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data);
variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[4].data = SIMD_SQRT(variance_4[4].data);
grad_4[5].data = SIMD_SQRT(variance_4[5].data);
grad_4[6].data = SIMD_SQRT(variance_4[6].data);
grad_4[7].data = SIMD_SQRT(variance_4[7].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[4].data = SIMD_FMA(grad_4[4].data, bias2_sqrt.data, eps_4.data);
grad_4[5].data = SIMD_FMA(grad_4[5].data, bias2_sqrt.data, eps_4.data);
grad_4[6].data = SIMD_FMA(grad_4[6].data, bias2_sqrt.data, eps_4.data);
grad_4[7].data = SIMD_FMA(grad_4[7].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_DIV(momentum_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_DIV(momentum_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data);
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(grad_4[4].data, step_size_4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(grad_4[5].data, step_size_4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(grad_4[6].data, step_size_4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(grad_4[7].data, step_size_4.data, param_4[7].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 2), param_4[4].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 7, param_4[7].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2),
param_4[4].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 2), momentum_4[4].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 5, momentum_4[5].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 6, momentum_4[6].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 7, momentum_4[7].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 2), variance_4[4].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 5, variance_4[5].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 6, variance_4[6].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int ds_adam_step(int optimizer_id,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq)
{
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep();
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
return 0;
}
int ds_adam_step_plus_copy(int optimizer_id,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
__half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep();
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
return 0;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
}
#include "custom_cuda_layers.h"
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
const float4* input_cast = reinterpret_cast<const float4*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < size) {
float4 data = input_cast[id];
float2 cast_data;
__half* output_h = reinterpret_cast<__half*>(&cast_data);
output_h[0] = (__half)data.x;
output_h[1] = (__half)data.y;
output_h[2] = (__half)data.z;
output_h[3] = (__half)data.w;
output_cast[id] = cast_data;
}
}
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 512;
size /= 4;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);
param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
#pragma once
#include <cpuid.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <x86intrin.h>
#include <cassert>
#include "context.h"
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define TILE (1024 * 1024 * 1024)
#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16
#else
#if defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8
#endif
#endif
class Adam_Optimizer {
public:
Adam_Optimizer(float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_buf_index(false)
{
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
}
~Adam_Optimizer()
{
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
}
void Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr);
void Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sa,
size_t param_size,
__half* dev_param = nullptr);
void Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params = nullptr);
inline void IncrementStep()
{
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
private:
#if defined(__AVX512__) or defined(__AVX256__)
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#else
__m256 data;
#endif
// float data_f[16];
};
#endif
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
float* _doubled_buffer[2];
bool _buf_index;
};
......@@ -264,3 +264,5 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
int rows,
int cols,
cudaStream_t stream);
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
......@@ -4,16 +4,18 @@ Copyright 2020 The Microsoft DeepSpeed Team
import sys
import types
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from deepspeed.runtime.lr_schedules import add_tuning_arguments
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.runtime.activation_checkpointing import checkpointing
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from deepspeed.utils import logger
from . import ops
from .runtime.engine import DeepSpeedEngine
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .utils import logger
try:
from deepspeed.git_version_info import version, git_hash, git_branch
from .git_version_info import version, git_hash, git_branch
except ImportError:
version = "0.0.0+unknown"
git_hash = None
......
from ..git_version_info import installed_ops as __installed_ops__
from . import lamb
from . import transformer
if __installed_ops__['sparse-attn']:
from . import sparse_attention
if __installed_ops__['cpu-adam']:
from . import adam
from .cpu_adam import DeepSpeedCPUAdam
import math
import torch
import importlib
ds_opt_adam = None
class DeepSpeedCPUAdam(torch.optim.Optimizer):
optimizer_id = 0
def __init__(self,
model_params,
lr=1e-3,
betas=(0.9,
0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False):
default_args = dict(lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad)
super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)
self.opt_id = DeepSpeedCPUAdam.optimizer_id
DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
global ds_opt_adam
ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)
def __setstate__(self, state):
super(DeepSpeedCPUAdam, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None, fp16_param_groups=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group_id, group in enumerate(self.param_groups):
for param_id, p in enumerate(group['params']):
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
print(f'group {group_id} param {param_id} = {p.numel()}')
state['step'] = 0
# gradient momentums
state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
state['step'] += 1
if fp16_param_groups is not None:
p_fp16 = fp16_param_groups[group_id][param_id]
ds_opt_adam.adam_update_copy(self.opt_id,
p.data,
grad,
exp_avg,
exp_avg_sq,
p_fp16)
else:
ds_opt_adam.adam_update(self.opt_id,
p.data,
grad,
exp_avg,
exp_avg_sq)
return loss
import torch
from torch.autograd import Variable
import collections
def async_migrate_to(obj, dev, main_stream=None):
if torch.is_tensor(obj):
obj = Variable(obj)
if isinstance(obj, Variable):
v = obj.cuda(dev, async=True)
if main_stream is not None:
v.data.record_stream(main_stream)
return v
elif isinstance(obj, collections.Mapping):
return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
elif isinstance(obj, collections.Sequence):
return [async_copy_to(o, dev, main_stream) for o in obj]
else:
return obj
def async_copy_to(obj, dev, main_stream=None):
if torch.is_tensor(obj):
obj = Variable(obj)
if isinstance(obj, Variable):
target = torch.empty_like(obj, device=dev).copy_(obj)
if main_stream is not None:
target.data.record_stream(main_stream)
return target
elif isinstance(obj, collections.Mapping):
return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
elif isinstance(obj, collections.Sequence):
return [async_copy_to(o, dev, main_stream) for o in obj]
......@@ -10,14 +10,21 @@ from deepspeed.runtime.constants import *
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.constants import *
from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from deepspeed.utils import logger
TENSOR_CORE_ALIGN_SIZE = 8
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER]
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
DEEPSPEED_ADAM = 'deepspeed_adam'
DEEPSPEED_OPTIMIZERS = [
ADAM_OPTIMIZER,
LAMB_OPTIMIZER,
ONEBIT_ADAM_OPTIMIZER,
DEEPSPEED_ADAM
]
def get_amp_enabled(param_dict):
......@@ -111,22 +118,9 @@ def get_zero_optimization(param_dict):
def get_zero_reduce_scatter(param_dict):
return get_scalar_param(param_dict, ZERO_REDUCE_SCATTER, ZERO_REDUCE_SCATTER_DEFAULT)
def get_zero_max_elements_per_comm(param_dict):
return get_scalar_param(param_dict,
ZERO_MAX_ELEMENTS_PER_COMM,
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT)
def get_allgather_size(param_dict):
return get_scalar_param(param_dict,
ALLGATHER_SIZE,
ALLGATHER_SIZE_DEFAULT) if get_scalar_param(
param_dict,
ALLGATHER_SIZE,
ALLGATHER_SIZE_DEFAULT) > 0 else ALLGATHER_SIZE_DEFAULT
ZERO_OPTIMIZATION_REDUCE_SCATTER,
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)
def get_allreduce_always_fp32(param_dict):
......@@ -493,8 +487,6 @@ class DeepSpeedConfig(object):
self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
self.allgather_size = get_allgather_size(param_dict)
self.zero_config = DeepSpeedZeroConfig(param_dict)
self.zero_optimization_stage = self.zero_config.stage
self.zero_enabled = self.zero_optimization_stage > 0
......@@ -628,15 +620,18 @@ class DeepSpeedConfig(object):
':'))))
def _do_error_check(self):
if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
GRADIENT_ACCUMULATION_STEPS)
if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
if self.zero_config.cpu_offload is True:
assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
#assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
def _do_warning_check(self):
fp16_enabled = self.fp16_enabled or self.zero_enabled
......
......@@ -183,35 +183,6 @@ Gradient clipping should be enabled as:
GRADIENT_CLIPPING = 'gradient_clipping'
GRADIENT_CLIPPING_DEFAULT = 0.
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": [0|1|2],
"zero_all_gather_size": 200
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DEFAULT = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_REDUCE_SCATTER = "zero_reduce_scatter"
ZERO_REDUCE_SCATTER_DEFAULT = True
ZERO_MAX_ELEMENTS_PER_COMM = "zero_max_elements_per_comm"
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT = 5e8
ALLGATHER_SIZE = 'allgather_size'
ALLGATHER_SIZE_DEFAULT = 500000000
#########################################
# FP32 AllReduce
#########################################
......
......@@ -7,6 +7,7 @@ import torch
import warnings
import torch.distributed as dist
import apex
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
......@@ -14,20 +15,20 @@ from tensorboardX import SummaryWriter
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_ADAM, DEEPSPEED_OPTIMIZERS
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT, \
TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.runtime.zero.constants import \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
......@@ -105,7 +106,6 @@ class DeepSpeedEngine(Module):
collate_fn=None,
config_params=None):
super(DeepSpeedEngine, self).__init__()
self.client_optimizer = optimizer
self.client_model_parameters = model_parameters
self.client_lr_scheduler = lr_scheduler
......@@ -266,7 +266,7 @@ class DeepSpeedEngine(Module):
return self._config.train_micro_batch_size_per_gpu
def optimizer_name(self):
return self._config.optimizer_name
return self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name
def optimizer_params(self):
return self._config.optimizer_params
......@@ -292,6 +292,9 @@ class DeepSpeedEngine(Module):
def zero_overlap_comm(self):
return self._config.zero_config.overlap_comm
def zero_cpu_offload(self):
return self._config.zero_config.cpu_offload
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
......@@ -310,9 +313,6 @@ class DeepSpeedEngine(Module):
def zero_load_from_fp32_weights(self):
return self._config.zero_config.load_from_fp32_weights
def allgather_size(self):
return self._config.allgather_size
def fp16_enabled(self):
return self._config.fp16_enabled
......@@ -491,6 +491,7 @@ class DeepSpeedEngine(Module):
# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):
if client_optimizer is not None:
basic_optimizer = client_optimizer
logger.info('Using client Optimizer as basic optimizer')
......@@ -504,13 +505,14 @@ class DeepSpeedEngine(Module):
if self.zero_optimization():
assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
if self.optimizer_name() != ADAM_OPTIMIZER:
if not is_zero_supported_optimizer(basic_optimizer):
assert self.zero_allow_untested_optimizer(), \
'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
logger.warning(
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.amp_enabled():
assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
......@@ -522,8 +524,8 @@ class DeepSpeedEngine(Module):
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
else:
self.optimizer = basic_optimizer
# logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
......@@ -533,8 +535,14 @@ class DeepSpeedEngine(Module):
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
)
if self.optimizer_name() == ADAM_OPTIMIZER:
from apex.optimizers.fused_adam import FusedAdam
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
if self.zero_cpu_offload():
optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
else:
from apex.optimizers.fused_adam import FusedAdam
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == DEEPSPEED_ADAM:
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
......@@ -550,8 +558,9 @@ class DeepSpeedEngine(Module):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name(
) == ONEBIT_ADAM_OPTIMIZER:
if isinstance(optimizer,
apex.optimizers.FusedAdam) or self.optimizer_name(
) == ONEBIT_ADAM_OPTIMIZER:
if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale')
timers = self.timers if self.wall_clock_breakdown() else None
......@@ -616,9 +625,11 @@ class DeepSpeedEngine(Module):
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
cpu_offload=self.zero_cpu_offload(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor())
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps())
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
......@@ -724,7 +735,6 @@ class DeepSpeedEngine(Module):
return loss
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
#Zero stage 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
......@@ -780,6 +790,8 @@ class DeepSpeedEngine(Module):
self.timers('backward_inner').start()
if self.zero_optimization():
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
)
self.optimizer.backward(loss)
elif self.amp_enabled():
# AMP requires delaying unscale when inside gradient accumulation boundaries
......@@ -854,7 +866,6 @@ class DeepSpeedEngine(Module):
master_params = amp.master_params(self.optimizer)
torch.nn.utils.clip_grad_norm_(parameters=master_params,
max_norm=self.gradient_clipping())
self.optimizer.step()
#zero grad in basic optimizer could be unreliable and may not exhibit
......@@ -957,6 +968,9 @@ class DeepSpeedEngine(Module):
def get_lr(self):
return self._get_optimizer_param('lr')
def get_type(self):
return self._get_optimizer_param('type')
def get_mom(self):
return self._get_optimizer_param('betas')
......
......@@ -5,79 +5,7 @@ Licensed under the MIT license.
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.utils import logger
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER:
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
}
from deepspeed.runtime.zero.constants import *
class DeepSpeedZeroConfig(object):
......@@ -92,6 +20,7 @@ class DeepSpeedZeroConfig(object):
self.allgather_bucket_size = None
self.overlap_comm = None
self.load_from_fp32_weights = None
self.cpu_offload = None
if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
......@@ -156,7 +85,12 @@ class DeepSpeedZeroConfig(object):
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
self.load_from_fp32_weights = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
self.cpu_offload = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_CPU_OFFLOAD,
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
"cpu_offload": [true|false]
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
}
......@@ -793,8 +793,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
......
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
from torch._six import inf
from torch.autograd import Variable
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.utils import logger
#Toggle this to true to enable correctness test
#with gradient partitioning and without
pg_correctness_test = False
try:
from apex_C import flatten
from apex_C import unflatten
except ImportError:
try:
_ = warned_flatten
except NameError:
logger.warning(
"apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
)
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
def input(msg):
return
def split_half_float_double(tensors):
dtypes = [
"torch.cuda.HalfTensor",
"torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor"
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def isclose(a, b, rtol=1e-09, atol=0.0):
return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
from fractions import gcd # or can import gcd from `math` in Python 3
return x * y // gcd(x, y)
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
return _flatten_dense_tensors(padded_tensor_list)
def get_alignment_padding(tensor_list, alignment):
num_elements = sum([tensor.numel() for tensor in tensor_list])
remainder = num_elements % alignment
return (alignment - remainder) if remainder else remainder
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")
class FP16_DeepSpeedZeroOptimizer(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
For usage examples, refer to TODO: DeepSpeed Tutorial
"""
def __init__(self,
init_optimizer,
timers,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
contiguous_gradients=True,
reduce_bucket_size=500000000,
allgather_bucket_size=5000000000,
dp_process_group=None,
reduce_scatter=True,
overlap_comm=False,
mpu=None,
clip_grad=0.0,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0):
if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {allgather_bucket_size}")
# The fused optimizer does all the work. We need this layer for two reason:
# 1. maintain same user API from apex.fp16_utils
# 2. keep common stuff here in case we need to add ne552w fused optimizer later
# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params requires grad
# - flat by groups, not keeping state. TODO: remove state explicitly?
# - master gard and unflat master weight never exist. TODO: a way to save out unflat master?
if not torch.cuda.is_available:
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.timers = timers
self.reduce_scatter = reduce_scatter
self.overlap_comm = overlap_comm
self.dp_process_group = dp_process_group
self.partition_count = dist.get_world_size(group=self.dp_process_group)
if mpu is None:
self.model_parallel_group = None
self.model_parallel_rank = 0
else:
self.model_parallel_group = mpu.get_model_parallel_group()
self.model_parallel_rank = mpu.get_model_parallel_rank()
self.overflow = False
self.clip_grad = clip_grad
self.allreduce_always_fp32 = allreduce_always_fp32
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
#param partitioned by data parallel degree
#this will contain a list of equal sized tensors
#each of which will be updated by a different process
self.parallel_partitioned_fp16_groups = []
#a single 32-bit partition of the parallel partitioned parameters
#that this process will update
self.single_partition_of_fp32_groups = []
#param partition info
#These are the parameters in each group that will not be updated by this process directly
self.params_not_in_partition = []
#These are the parameters that will be updated by this process directly
self.params_in_partition = []
#Offset from the first paramter in the the self.params_in_partition
#the parameter boundaries may not align with partition boundaries
#so we need to keep track of the offset
self.first_offset = []
#number of elements per partition in each group
self.partition_size = []
partition_id = dist.get_rank(group=self.dp_process_group)
self.all_reduce_print = False
# padding on each partition for alignment purposes
self.groups_padding = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
# Record padding required to align group to world size
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
padding = get_alignment_padding(self.fp16_groups[i],
self.partition_count)
else:
padding = 0
self.groups_padding.append(padding)
#not sure why apex was cloning the weights before flattening
#removing cloning here
see_memory_usage(f"Before moving param group {i} to CPU")
#move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_groups[i])
see_memory_usage(f"After moving param group {i} to CPU")
#create flat buffer in CPU and move to GPU
self.fp16_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
see_memory_usage(f"After flattening and moving param group {i} to GPU")
if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(
f"After Flattening and after emptying param group {i} cache")
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
#divide the flat weights into near equal paritition equal to the data parallel degree
#each process will compute on a different part of the partition
data_parallel_partitions = self.get_data_parallel_partitions(
self.fp16_groups_flat[i])
self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
# a partition of the fp32 master weights that will be updated by this process
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i]
[partition_id].clone().float().detach())
# modify optimizer of have flat master weight
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
group=self.dp_process_group)
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(self.fp16_groups[i], partition_size, partition_id)
self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)
self.reduce_bucket_size = int(reduce_bucket_size)
self.allgather_bucket_size = int(allgather_bucket_size)
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream()
self.callback_queued = False
self.param_dict = {}
#map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
self.contiguous_gradients = contiguous_gradients
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
self.params_already_reduced = []
self._release_ipg_buffers()
self.previous_reduced_grads = None
#simplified param id
self.param_id = {}
count = 0
for i, params_group in enumerate(self.fp16_groups):
for param in params_group:
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
count = count + 1
for param_group in self.params_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = True
for param_group in self.params_not_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = False
#mapping from parameter to partition that it belongs to
self.param_to_partition_ids = {}
#stores if a partition has been reduced in this step
self.is_partition_reduced = {}
#number of grads in partition that still need to be computed
self.remaining_grads_in_partition = {}
#total number of grads in partition
self.total_grads_in_partition = {}
#stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
#stores the offset at which a parameter gradient needs to be inserted in a partition
self.grad_partition_insertion_offset = {}
#the offset in the gradient at which it must be inserted at the beginning of the paritition
self.grad_start_offset = {}
#will store the averaged gradients required by this parititon
self.averaged_gradients = {}
# store index of first parameter in each partition
self.first_param_index_in_partition = {}
#initializes all data structures for implementing gradient partitioning
self.initialize_gradient_partitioning_data_structures()
#resets the data structure value for the next backward propagation
self.reset_partition_gradient_structures()
#creates backward hooks for gradient partitioning
self.create_reduce_and_remove_grad_hooks()
# we may have a way of fusing dynamic scale. Do not support for now
if dynamic_loss_scale:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
see_memory_usage("Before initializing optimizer states")
self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states")
if dist.get_rank() == 0:
logger.info(f"optimizer state initialized")
if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(f"After initializing ZeRO optimizer")
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
self.grads_in_partition = None
self.grads_in_partition_offset = 0
def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
single_grad_partition = torch.zeros(
int(self.partition_size[i]),
dtype=self.single_partition_of_fp32_groups[i].dtype,
device=torch.cuda.current_device())
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
self.optimizer.step()
for group in self.single_partition_of_fp32_groups:
group.grad = None
return
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, param_group in enumerate(self.fp16_groups):
self.param_to_partition_ids[i] = {}
self.is_partition_reduced[i] = {}
self.total_grads_in_partition[i] = {}
self.remaining_grads_in_partition[i] = {}
self.is_grad_computed[i] = {}
self.grad_partition_insertion_offset[i] = {}
self.grad_start_offset[i] = {}
self.first_param_index_in_partition[i] = {}
for partition_id in range(total_partitions):
self.is_grad_computed[i][partition_id] = {}
self.grad_partition_insertion_offset[i][partition_id] = {}
self.grad_start_offset[i][partition_id] = {}
self.total_grads_in_partition[i][partition_id] = 0
self.initialize_gradient_partition(i, param_group, partition_id)
self.is_partition_reduced[i][partition_id] = False
self.first_param_index_in_partition[i][
partition_id] = self.get_first_param_index(
i,
param_group,
partition_id)
def independent_gradient_partition_epilogue(self):
self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
self.reduce_ipg_grads()
self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
#if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
if self.overlap_comm:
torch.cuda.synchronize()
for i, _ in enumerate(self.fp16_groups):
if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=torch.half,
device=torch.cuda.current_device(),
return_tensor_list=True)
else:
#When gradient accumulation is greater that 1
#This code path will be triggered and will add
#to the accumulated averaged gradients
avg_new = self.get_flat_partition(self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=torch.half,
device=torch.cuda.current_device(),
return_tensor_list=True)
for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i],avg_new):
accumulated_grad.add_(new_avg_grad)
self._release_ipg_buffers()
# No need to keep the gradients anymore.
# All gradients required by the step
# are in self.averaged_gradients
self.zero_grad()
see_memory_usage(f"End ipg_epilogue")
# resets all partition to no reduced
# sets remianing grads to the total number of grads in each partition
# set is grad computed to false for all grads in partition
def reset_partition_gradient_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.fp16_groups):
for partition_id in range(total_partitions):
self.is_partition_reduced[i][partition_id] = False
self.remaining_grads_in_partition[i][
partition_id] = self.total_grads_in_partition[i][partition_id]
for param_id in self.is_grad_computed[i][partition_id]:
self.is_grad_computed[i][partition_id][param_id] = False
def initialize_gradient_partition(self, i, param_group, partition_id):
def set_key_value_list(dictionary, key, value):
if key in dictionary:
dictionary[key].append(value)
else:
dictionary[key] = [value]
def increment_value(dictionary, key):
if key in dictionary:
dictionary[key] += 1
else:
dictionary[key] = 1
partition_size = self.partition_size[i]
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for param in param_group:
param_size = param.numel()
param_id = self.get_param_id(param)
if (current_index >= start_index and current_index < end_index):
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][
param_id] = current_index - start_index
self.grad_start_offset[i][partition_id][param_id] = 0
elif start_index > current_index and start_index < (current_index +
param_size):
assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
self.grad_start_offset[i][partition_id][param_id] = first_offset
current_index = current_index + param_size
def overlapping_partition_gradients_reduce_epilogue(self):
self.independent_gradient_partition_epilogue()
def create_reduce_and_remove_grad_hooks(self):
self.grad_accs = []
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
if param.requires_grad:
def wrapper(param, i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param, i)
grad_acc.register_hook(reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)
wrapper(param, i)
def get_param_id(self, param):
unique_id = id(param)
return self.param_id[unique_id]
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
see_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
)
###############Idependent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
param.numel())
self.reduce_ipg_grads()
if self.contiguous_gradients and self.overlap_comm:
# Swap ipg_index between 0 and 1
self.ipg_index = 1 - self.ipg_index
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
param.numel())
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
#keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
if self.contiguous_gradients:
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
0,
self.elements_in_ipg_bucket,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.elements_in_ipg_bucket += param.numel()
self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
def print_rank_0(self, message):
if dist.get_rank() == 0:
logger.info(message)
def gradient_reduction_w_predivide(self, tensor):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.postscale_gradients:
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.gradient_predivide_factor() != dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor() /
dp_world_size)
else:
tensor_to_allreduce.div_(dp_world_size)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
def average_tensor(self, tensor):
if self.overlap_comm:
torch.cuda.synchronize()
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
if not self.reduce_scatter:
self.gradient_reduction_w_predivide(tensor)
return
# Accumulate destination ranks and bucket offsets for each gradient slice.
# Note: potential future optimization, record access pattern of parameters
# in backward pass and partition gradients w.r.t. access pattern so that our
# bucket is guaranteed to be contiguous w.r.t. ranks
rank_and_offsets = []
curr_size = 0
prev_id = -1
for i, param, param_id in self.params_in_ipg_bucket:
partition_ids = self.param_to_partition_ids[i][param_id]
partition_size = self.partition_size[i]
# Get all partition ids + their offsets
partition_ids_w_offsets = []
for partition_id in partition_ids:
offset = self.grad_start_offset[i][partition_id][param_id]
partition_ids_w_offsets.append((partition_id, offset))
partition_ids_w_offsets.sort(key=lambda t: t[1])
# Calculate rank and offsets for grad slices
for idx in range(len(partition_ids_w_offsets)):
partition_id, offset = partition_ids_w_offsets[idx]
# Calculate numel for grad slice depending on partition location
if idx == len(partition_ids_w_offsets) - 1:
# Last partition_id uses its own offset
numel = param.numel() - offset
else:
# Set numel to next partition's offset
numel = partition_ids_w_offsets[idx + 1][1] - offset
# Merge bucket ranges if they belong to the same rank
if partition_id == prev_id:
prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel)
else:
rank_and_offsets.append((partition_id, curr_size, numel))
curr_size += numel
prev_id = partition_id
tensor.div_(dist.get_world_size(group=self.dp_process_group))
async_handles = []
for dst, bucket_offset, numel in rank_and_offsets:
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
dst_rank = _get_global_rank(self.dp_process_group, dst)
async_handle = dist.reduce(grad_slice,
dst=dst_rank,
group=self.dp_process_group,
async_op=True)
async_handles.append(async_handle)
for handle in async_handles:
handle.wait()
def copy_grads_in_partition(self, param):
if self.grads_in_partition is None:
self.grads_in_partition_offset = 0
total_size = 0
for group in self.params_in_partition:
for param_in_partition in group:
total_size += param_in_partition.numel()
see_memory_usage(f"before copying {total_size} gradients into partition")
self.grads_in_partition = torch.empty(int(total_size),
dtype=torch.half,
device=torch.cuda.current_device())
see_memory_usage(f"after copying {total_size} gradients into partition")
#The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
new_grad_tensor = self.grads_in_partition.narrow(0,
self.grads_in_partition_offset,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.grads_in_partition_offset += param.numel()
def reduce_ipg_grads(self):
if self.overlap_comm:
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
if self.contiguous_gradients:
self.average_tensor(self.ipg_buffer[self.ipg_index])
else:
self.buffered_reduce_fallback(
None,
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
with torch.cuda.stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
self.params_already_reduced[param_id] = True
if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear the previous grads during the next reduction
# to avoid clearing them before the reduction is complete.
if self.previous_reduced_grads is None:
self.previous_reduced_grads = []
self.previous_reduced_grads.append(param)
else:
param.grad = None
elif self.contiguous_gradients:
self.copy_grads_in_partition(param)
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
#####################################################################
def reduce_ready_partitions_and_remove_grads(self, param, i):
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
def zero_reduced_gradients(self, partition_id, i):
def are_all_related_partitions_reduced(params_id):
for partition_id in self.param_to_partition_ids[i][params_id]:
if not self.is_partition_reduced[i][partition_id]:
return False
return True
for params_id in self.is_grad_computed[i][partition_id]:
if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None
def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = _flatten_dense_tensors(tensors)
def print_func():
logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
self.sequential_execution(print_func, message)
def get_grads_to_reduce(self, i, partition_id):
def get_reducable_portion(key):
grad = self.param_dict[key].grad
total_elements = grad.numel()
start = self.grad_start_offset[i][partition_id][key]
num_elements = min(
total_elements - start,
self.partition_size[i] -
self.grad_partition_insertion_offset[i][partition_id][key])
if not pg_correctness_test:
if num_elements == total_elements:
return grad
else:
return grad.contiguous().view(-1).narrow(0,
int(start),
int(num_elements))
else:
if num_elements == total_elements:
return grad.clone()
else:
return grad.clone().contiguous().view(-1).narrow(
0,
int(start),
int(num_elements))
grads_to_reduce = []
for key in self.is_grad_computed[i][partition_id]:
grad = get_reducable_portion(key)
grads_to_reduce.append(grad)
return grads_to_reduce
def sequential_execution(self, function, message, group=None):
if group is None:
group = self.dp_process_group
if dist.get_rank(group=group) == 0:
logger.info(message)
for id in range(dist.get_world_size(group=group)):
if id == dist.get_rank(group=group):
function()
dist.barrier(group=group)
def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
param.grad = torch.zero_like(param)
######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
rank = None
tensor = flatten(bucket)
tensor_to_allreduce = tensor
if pg_correctness_test:
allreduce_always_fp32 = True
if allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = _get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
tensor.copy_(tensor_to_allreduce)
return tensor
#if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
if self.overlap_comm:
torch.cuda.synchronize()
if self.previous_reduced_grads is not None:
# previous_reduced_grads has the previous reduced grads,
# now it is safe to clear.
for param in self.previous_reduced_grads:
param.grad = None
self.previous_reduced_grads = None
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self,
bucket,
numel_per_bucket=500000000,
rank=None,
log=None):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
#allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self,
rank,
grads,
elements_per_buffer=500000000,
log=None):
split_buckets = split_half_float_double(grads)
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket,
numel_per_bucket=elements_per_buffer,
rank=rank,
log=log)
#############################################################################
#############################################################################
#############################################################################
#views the tensor as multiple partitions and returns
#those partitions
def get_data_parallel_partitions(self, tensor):
partitions = []
dp = dist.get_world_size(group=self.dp_process_group)
dp_id = dist.get_rank(group=self.dp_process_group)
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if (current_index >= start_index and current_index < end_index):
params_in_partition.append(tensor)
elif start_index > current_index and start_index < (current_index +
tensor_size):
params_in_partition.append(tensor)
assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _model_parallel_all_reduce(self, tensor, op):
""" Perform all reduce within model parallel group, if any.
"""
if self.model_parallel_group is None:
torch.distributed.all_reduce(tensor=tensor, op=op)
else:
torch.distributed.all_reduce(tensor=tensor,
op=op,
group=self.model_parallel_group)
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
#if dist.get_rank() == 0:
# logger.info(f"Total Norm begining {total_norm}")
for g, p in zip(gradients, params):
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
#creates a flat fused tensor from the tensor list starting at the first_offset
#in the first tensor of the list. If there are not enough elements in the tensor
#list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
partition_size,
dtype,
device,
return_tensor_list=False):
flat_tensor_list = []
current_size = 0
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
continue
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
#we dont need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
#we need a narrow view of the tensor based on the tensor offset and number of elements that
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)))
else:
flat_tensor_list.append(tensor)
current_size = current_size + num_elements
#this means its the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
dtype=dtype,
device=device))
if return_tensor_list:
return flat_tensor_list
return _flatten_dense_tensors(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None
def step(self, closure=None):
"""
Not supporting closure.
"""
see_memory_usage(f"In step before checking overflow")
# First compute norm for all group so we know if there is overflow
self.check_overflow()
timers = self.timers
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
see_memory_usage('After overflow before clearing gradients')
self.zero_grad()
for key in self.averaged_gradients:
self.averaged_gradients[key] = None
see_memory_usage('After overflow after clearing gradients')
logger.info(
"[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
timers('optimizer_step').start()
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
timers('optimizer_allgather').stop()
return
norm_groups = []
single_partition_grad_groups = []
skip = False
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
#free gradients for all the prameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_in_partition[i])
#create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i])).to(
self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = _flatten_dense_tensors(
self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
#release all the gradient since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
single_partition_grad_groups.append(single_grad_partition)
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
timers('optimizer_step').start()
self.optimizer.step()
#get rid of the fp32 gradients. Not needed anymore
for group in self.single_partition_of_fp32_groups:
group.grad = None
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
#gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
#Sequential AllGather Best of both worlds
dp_world_size = dist.get_world_size(group=self.dp_process_group)
num_shards = max(
1,
partitioned_params[partition_id].numel() * dp_world_size //
self.allgather_bucket_size)
shard_size = partitioned_params[partition_id].numel() // num_shards
num_elements = shard_size
assert shard_size * num_shards <= partitioned_params[partition_id].numel()
for shard_id in range(num_shards):
if shard_id == (num_shards - 1):
num_elements = partitioned_params[partition_id].numel(
) - shard_id * shard_size
shard_list = []
for dp_id in range(dp_world_size):
curr_shard = partitioned_params[dp_id].narrow(
0,
shard_id * shard_size,
num_elements).detach()
shard_list.append(curr_shard)
dist.all_gather(shard_list,
shard_list[partition_id],
group=self.dp_process_group)
timers('optimizer_allgather').stop()
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
see_memory_usage('After zero_optimizer step')
return
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.fp16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
def has_overflow(self, partition_gradients=True):
if partition_gradients:
overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
else:
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
overflow_gpu = torch.cuda.ByteTensor([overflow])
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu,
op=torch.distributed.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(self.reduce_bucket_size,
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(self.reduce_bucket_size,
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains a single flattened tensor.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_length = group.numel() - self.groups_padding[i]
groups_without_padding.append(group[:lean_length])
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_optimizer_state = self._get_state_without_padding(
self.optimizer.state[p],
self.groups_padding[i])
optimizer_groups_state.append(lean_optimizer_state)
return optimizer_groups_state
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
state_dict['partition_count'] = self.partition_count
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
self.single_partition_of_fp32_groups)
state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
return state_dict
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
partition_id = dist.get_rank(group=self.dp_process_group)
merged_single_partition_of_fp32_groups = []
for i in range(len(self.single_partition_of_fp32_groups)):
merged_partitions = [
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
]
flat_merged_partitions = flatten_dense_tensors_aligned(
merged_partitions,
dist.get_world_size(group=self.dp_process_group))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp32_partition.data.copy_(fp16_partitions[partition_id].data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
return dp_partitions[partition_id]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [
sd['base_optimizer_state'][i] for sd in all_state_dict
]
for key in all_partition_group_states[0].keys():
all_partition_states = [
all_states[key] for all_states in all_partition_group_states
]
partition_states[key] = self._partition_base_optimizer_state(
key,
all_partition_states)
base_optimizer_group_states.append(partition_states)
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
current = self.optimizer.state[p][key]
current.data.copy_(saved.data)
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading ZeRO checkpoint
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
"""
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict_list[0]['loss_scaler']
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
self.overflow = state_dict_list[0]['overflow']
if load_optimizer_states:
self._restore_base_optimizer_state(state_dict_list)
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 1 if changing DP degree and option 2 otherwise.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
if load_from_fp32_weights:
self._restore_from_fp32_weights(state_dict_list)
else:
self._restore_from_fp16_weights()
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
logger.info(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
from torch._six import inf
from torch.autograd import Variable
import collections
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
#Toggle this to true to enable correctness test
#with gradient partitioning and without
pg_correctness_test = False
try:
from apex_C import flatten
from apex_C import unflatten
except ImportError:
try:
_ = warned_flatten
except NameError:
logger.warning(
"apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
)
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
def input(msg):
return
def split_half_float_double(tensors):
dtypes = [
"torch.cuda.HalfTensor",
"torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor"
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def isclose(a, b, rtol=1e-09, atol=0.0):
return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
from fractions import gcd # or can import gcd from `math` in Python 3
return x * y // gcd(x, y)
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
return _flatten_dense_tensors(padded_tensor_list)
def get_alignment_padding(tensor_list, alignment):
num_elements = sum([tensor.numel() for tensor in tensor_list])
remainder = num_elements % alignment
return (alignment - remainder) if remainder else remainder
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")
class FP16_DeepSpeedZeroOptimizer(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
For usage examples, refer to TODO: DeepSpeed Tutorial
"""
def __init__(self,
init_optimizer,
timers,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
contiguous_gradients=True,
reduce_bucket_size=500000000,
allgather_bucket_size=5000000000,
dp_process_group=None,
reduce_scatter=True,
overlap_comm=False,
cpu_offload=False,
mpu=None,
clip_grad=0.0,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1):
if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {allgather_bucket_size}")
logger.info(f"CPU Offload: {cpu_offload}")
# The fused optimizer does all the work. We need this layer for two reason:
# 1. maintain same user API from apex.fp16_utils
# 2. keep common stuff here in case we need to add ne552w fused optimizer later
# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params requires grad
# - flat by groups, not keeping state. TODO: remove state explicitly?
# - master gard and unflat master weight never exist. TODO: a way to save out unflat master?
if not torch.cuda.is_available:
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.timers = timers
self.reduce_scatter = reduce_scatter
self.overlap_comm = overlap_comm
self.cpu_offload = cpu_offload
self.deepspeed_adam_offload = (cpu_offload
and type(init_optimizer) == DeepSpeedCPUAdam)
self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
self.dp_process_group = dp_process_group
self.partition_count = dist.get_world_size(group=self.dp_process_group)
if mpu is None:
self.model_parallel_group = None
self.model_parallel_rank = 0
else:
self.model_parallel_group = mpu.get_model_parallel_group()
self.model_parallel_rank = mpu.get_model_parallel_rank()
self.overflow = False
self.clip_grad = clip_grad
self.allreduce_always_fp32 = allreduce_always_fp32
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
#param partitioned by data parallel degree
#this will contain a list of equal sized tensors
#each of which will be updated by a different process
self.parallel_partitioned_fp16_groups = []
#a single 32-bit partition of the parallel partitioned parameters
#that this process will update
self.single_partition_of_fp32_groups = []
#param partition info
#These are the parameters in each group that will not be updated by this process directly
self.params_not_in_partition = []
#These are the parameters that will be updated by this process directly
self.params_in_partition = []
#Offset from the first paramter in the the self.params_in_partition
#the parameter boundaries may not align with partition boundaries
#so we need to keep track of the offset
self.first_offset = []
#number of elements per partition in each group
self.partition_size = []
partition_id = dist.get_rank(group=self.dp_process_group)
self.all_reduce_print = False
# padding on each partition for alignment purposes
self.groups_padding = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
# Record padding required to align group to world size
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
padding = get_alignment_padding(self.fp16_groups[i],
self.partition_count)
else:
padding = 0
self.groups_padding.append(padding)
#not sure why apex was cloning the weights before flattening
#removing cloning here
see_memory_usage(f"Before moving param group {i} to CPU")
#move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_groups[i])
see_memory_usage(f"After moving param group {i} to CPU")
#create flat buffer in CPU and move to GPU
self.fp16_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
see_memory_usage(f"After flattening and moving param group {i} to GPU")
if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(
f"After Flattening and after emptying param group {i} cache")
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
#divide the flat weights into near equal paritition equal to the data parallel degree
#each process will compute on a different part of the partition
data_parallel_partitions = self.get_data_parallel_partitions(
self.fp16_groups_flat[i])
self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
# a partition of the fp32 master weights that will be updated by this process
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i][partition_id].to(
self.device).clone().float().detach())
# modify optimizer of have flat master weight
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
group=self.dp_process_group)
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(self.fp16_groups[i], partition_size, partition_id)
self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)
self.reduce_bucket_size = int(reduce_bucket_size)
self.allgather_bucket_size = int(allgather_bucket_size)
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream()
self.cpu_computation_stream = torch.cuda.Stream()
self.migration_stream = torch.cuda.Stream()
self.callback_queued = False
self.param_dict = {}
#map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
# CPU-Offload requires contiguous gradients
self.contiguous_gradients = contiguous_gradients or cpu_offload
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
self.params_already_reduced = []
self._release_ipg_buffers()
self.previous_reduced_grads = None
#simplified param id
self.param_id = {}
largest_param_numel = 0
count = 0
for i, params_group in enumerate(self.fp16_groups):
for param in params_group:
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
if param.numel() > largest_param_numel:
largest_param_numel = param.numel()
count = count + 1
for param_group in self.params_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = True
for param_group in self.params_not_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = False
if self.cpu_offload:
self.accumulated_grads_in_cpu = {}
self.norm_for_param_grads = {}
self.local_overflow = False
self.grad_position = {}
self.temp_grad_buffer_for_cpu_offload = torch.zeros(
largest_param_numel,
device=self.device).half().pin_memory()
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
largest_param_numel,
device=torch.cuda.current_device()).half()
for i, params_group in enumerate(self.fp16_groups):
self.get_grad_position(i,
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i])
#mapping from parameter to partition that it belongs to
self.param_to_partition_ids = {}
#stores if a partition has been reduced in this step
self.is_partition_reduced = {}
#number of grads in partition that still need to be computed
self.remaining_grads_in_partition = {}
#total number of grads in partition
self.total_grads_in_partition = {}
#stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
#stores the offset at which a parameter gradient needs to be inserted in a partition
self.grad_partition_insertion_offset = {}
#the offset in the gradient at which it must be inserted at the beginning of the paritition
self.grad_start_offset = {}
#will store the averaged gradients required by this parititon
self.averaged_gradients = {}
# store index of first parameter in each partition
self.first_param_index_in_partition = {}
#initializes all data structures for implementing gradient partitioning
self.initialize_gradient_partitioning_data_structures()
#resets the data structure value for the next backward propagation
self.reset_partition_gradient_structures()
#creates backward hooks for gradient partitioning
self.create_reduce_and_remove_grad_hooks()
# we may have a way of fusing dynamic scale. Do not support for now
if dynamic_loss_scale:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
see_memory_usage("Before initializing optimizer states")
self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states")
if dist.get_rank() == 0:
logger.info(f"optimizer state initialized")
if dist.get_rank(group=self.dp_process_group) == 0:
see_memory_usage(f"After initializing ZeRO optimizer")
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
self.grads_in_partition = None
self.grads_in_partition_offset = 0
def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
single_grad_partition = torch.zeros(
int(self.partition_size[i]),
dtype=self.single_partition_of_fp32_groups[i].dtype,
device=self.device)
self.single_partition_of_fp32_groups[
i].grad = single_grad_partition.pin_memory(
) if self.cpu_offload else single_grad_partition
self.optimizer.step()
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None
return
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, param_group in enumerate(self.fp16_groups):
self.param_to_partition_ids[i] = {}
self.is_partition_reduced[i] = {}
self.total_grads_in_partition[i] = {}
self.remaining_grads_in_partition[i] = {}
self.is_grad_computed[i] = {}
self.grad_partition_insertion_offset[i] = {}
self.grad_start_offset[i] = {}
self.first_param_index_in_partition[i] = {}
for partition_id in range(total_partitions):
self.is_grad_computed[i][partition_id] = {}
self.grad_partition_insertion_offset[i][partition_id] = {}
self.grad_start_offset[i][partition_id] = {}
self.total_grads_in_partition[i][partition_id] = 0
self.initialize_gradient_partition(i, param_group, partition_id)
self.is_partition_reduced[i][partition_id] = False
self.first_param_index_in_partition[i][
partition_id] = self.get_first_param_index(
i,
param_group,
partition_id)
def independent_gradient_partition_epilogue(self):
self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
self.reduce_ipg_grads()
self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
#if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
if self.overlap_comm:
torch.cuda.synchronize()
if self.cpu_offload is False:
for i, _ in enumerate(self.fp16_groups):
if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=torch.half,
device=torch.cuda.current_device(),
return_tensor_list=True)
else:
avg_new = self.get_flat_partition(self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=torch.half,
device=torch.cuda.current_device(),
return_tensor_list=True)
for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i],avg_new):
accumulated_grad.add_(new_avg_grad)
self._release_ipg_buffers()
# No need to keep the gradients anymore.
# All gradients required by the step
# are in self.averaged_gradients
self.zero_grad()
see_memory_usage(f"End ipg_epilogue")
# resets all partition to no reduced
# sets remianing grads to the total number of grads in each partition
# set is grad computed to false for all grads in partition
def reset_partition_gradient_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.fp16_groups):
for partition_id in range(total_partitions):
self.is_partition_reduced[i][partition_id] = False
self.remaining_grads_in_partition[i][
partition_id] = self.total_grads_in_partition[i][partition_id]
for param_id in self.is_grad_computed[i][partition_id]:
self.is_grad_computed[i][partition_id][param_id] = False
def initialize_gradient_partition(self, i, param_group, partition_id):
def set_key_value_list(dictionary, key, value):
if key in dictionary:
dictionary[key].append(value)
else:
dictionary[key] = [value]
def increment_value(dictionary, key):
if key in dictionary:
dictionary[key] += 1
else:
dictionary[key] = 1
partition_size = self.partition_size[i]
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for param in param_group:
param_size = param.numel()
param_id = self.get_param_id(param)
if (current_index >= start_index and current_index < end_index):
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][
param_id] = current_index - start_index
self.grad_start_offset[i][partition_id][param_id] = 0
elif start_index > current_index and start_index < (current_index +
param_size):
assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
self.grad_start_offset[i][partition_id][param_id] = first_offset
current_index = current_index + param_size
def overlapping_partition_gradients_reduce_epilogue(self):
self.independent_gradient_partition_epilogue()
def create_reduce_and_remove_grad_hooks(self):
self.grad_accs = []
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
if param.requires_grad:
def wrapper(param, i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param, i)
grad_acc.register_hook(reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)
wrapper(param, i)
def get_param_id(self, param):
unique_id = id(param)
return self.param_id[unique_id]
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
see_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
)
###############Idependent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
param.numel())
self.reduce_ipg_grads()
if self.contiguous_gradients and self.overlap_comm:
# Swap ipg_index between 0 and 1
self.ipg_index = 1 - self.ipg_index
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
param.numel())
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
#keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
if self.contiguous_gradients:
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
0,
self.elements_in_ipg_bucket,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.elements_in_ipg_bucket += param.numel()
self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
def print_rank_0(self, message):
if dist.get_rank() == 0:
logger.info(message)
def gradient_reduction_w_predivide(self, tensor):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.postscale_gradients:
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.gradient_predivide_factor != dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
else:
tensor_to_allreduce.div_(dp_world_size)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
def average_tensor(self, tensor):
if self.overlap_comm:
torch.cuda.synchronize()
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
if not self.reduce_scatter:
self.gradient_reduction_w_predivide(tensor)
return
# Accumulate destination ranks and bucket offsets for each gradient slice.
# Note: potential future optimization, record access pattern of parameters
# in backward pass and partition gradients w.r.t. access pattern so that our
# bucket is guaranteed to be contiguous w.r.t. ranks
rank_and_offsets = []
curr_size = 0
prev_id = -1
for i, param, param_id in self.params_in_ipg_bucket:
partition_ids = self.param_to_partition_ids[i][param_id]
partition_size = self.partition_size[i]
# Get all partition ids + their offsets
partition_ids_w_offsets = []
for partition_id in partition_ids:
offset = self.grad_start_offset[i][partition_id][param_id]
partition_ids_w_offsets.append((partition_id, offset))
partition_ids_w_offsets.sort(key=lambda t: t[1])
# Calculate rank and offsets for grad slices
for idx in range(len(partition_ids_w_offsets)):
partition_id, offset = partition_ids_w_offsets[idx]
# Calculate numel for grad slice depending on partition location
if idx == len(partition_ids_w_offsets) - 1:
# Last partition_id uses its own offset
numel = param.numel() - offset
else:
# Set numel to next partition's offset
numel = partition_ids_w_offsets[idx + 1][1] - offset
# Merge bucket ranges if they belong to the same rank
if partition_id == prev_id:
prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel)
else:
rank_and_offsets.append((partition_id, curr_size, numel))
curr_size += numel
prev_id = partition_id
tensor.div_(dist.get_world_size(group=self.dp_process_group))
async_handles = []
for dst, bucket_offset, numel in rank_and_offsets:
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
dst_rank = _get_global_rank(self.dp_process_group, dst)
async_handle = dist.reduce(grad_slice,
dst=dst_rank,
group=self.dp_process_group,
async_op=True)
async_handles.append(async_handle)
for handle in async_handles:
handle.wait()
##############################################################################
############################# CPU Offload Methods#############################
##############################################################################
def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):
current_offset = 0
for i, tensor in enumerate(tensor_list):
param_id = self.get_param_id(tensor)
param_start_offset = 0
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
param_start_offset = first_offset
#we dont need all elements of the tensor
if num_elements > (partition_size - current_offset):
num_elements = partition_size - current_offset
self.grad_position[param_id] = [
int(group_id),
int(param_start_offset),
int(current_offset),
int(num_elements)
]
current_offset += num_elements
def update_overflow_tracker_for_param_grad(self, param):
if param.grad is not None and self._has_inf_or_nan(param.grad.data):
self.local_overflow = True
def async_accumulate_grad_in_cpu(self, param):
param_id = self.get_param_id(param)
#copy to a preexisiting buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_cpu_offload.view(-1).narrow(
0,
0,
param.numel())
dest_buffer.copy_(param.grad.view(-1), non_blocking=True)
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
dtype=param.dtype,
device=self.device).pin_memory()
self.accumulated_grads_in_cpu[param_id].add_(dest_buffer)
def async_accumulate_grad_in_cpu_via_gpu(self, param):
param_id = self.get_param_id(param)
#copy to a preexisiting buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
0,
0,
param.numel())
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
dtype=param.dtype,
device=self.device).pin_memory()
if self.micro_step_id > 0:
dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
non_blocking=True)
param.grad.data.view(-1).add_(dest_buffer)
#at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1),
non_blocking=True)
def set_norm_for_param_grad(self, param):
param_id = self.get_param_id(param)
accumulated_grad = self.accumulated_grads_in_cpu[
param_id] if self.gradient_accumulation_steps > 1 else param.grad
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)
def set_norm_for_param_grad_in_gpu(self, param):
param_id = self.get_param_id(param)
accumulated_grad = param.grad
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)
def async_inplace_copy_grad_to_fp32_buffer(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
0,
dest_offset,
num_elements)
if self.gradient_accumulation_steps > 1:
src_tensor = self.accumulated_grads_in_cpu[param_id].view(-1).narrow(
0,
source_offset,
num_elements)
else:
src_tensor = param.grad.view(-1).narrow(0,
source_offset,
num_elements).float()
dest_tensor.copy_(src_tensor, non_blocking=True)
def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
0,
dest_offset,
num_elements)
src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float()
dest_tensor.copy_(src_tensor, non_blocking=True)
param.grad = None
def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
norm_type = 2.0
for p in params:
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_id = self.get_param_id(p)
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
############################################################################################
def copy_grads_in_partition(self, param):
if self.cpu_offload:
if self.gradient_accumulation_steps > 1:
self.async_accumulate_grad_in_cpu_via_gpu(param)
if self.is_gradient_accumulation_boundary:
self.set_norm_for_param_grad_in_gpu(param)
self.update_overflow_tracker_for_param_grad(param)
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
return
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
if self.grads_in_partition is None:
self.grads_in_partition_offset = 0
total_size = 0
for group in self.params_in_partition:
for param_in_partition in group:
total_size += param_in_partition.numel()
see_memory_usage(f"before copying {total_size} gradients into partition")
self.grads_in_partition = torch.empty(int(total_size),
dtype=torch.half,
device=torch.cuda.current_device())
see_memory_usage(f"after copying {total_size} gradients into partition")
#The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
new_grad_tensor = self.grads_in_partition.view(-1).narrow(
0,
self.grads_in_partition_offset,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
#print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
self.grads_in_partition_offset += param.numel()
def reduce_ipg_grads(self):
if self.overlap_comm:
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
if self.contiguous_gradients:
self.average_tensor(self.ipg_buffer[self.ipg_index])
else:
self.buffered_reduce_fallback(
None,
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
with torch.cuda.stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
self.params_already_reduced[param_id] = True
if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear the previous grads during the next reduction
# to avoid clearing them before the reduction is complete.
if self.previous_reduced_grads is None:
self.previous_reduced_grads = []
self.previous_reduced_grads.append(param)
else:
param.grad = None
elif self.contiguous_gradients:
self.copy_grads_in_partition(param)
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
#####################################################################
def reduce_ready_partitions_and_remove_grads(self, param, i):
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
def zero_reduced_gradients(self, partition_id, i):
def are_all_related_partitions_reduced(params_id):
for partition_id in self.param_to_partition_ids[i][params_id]:
if not self.is_partition_reduced[i][partition_id]:
return False
return True
for params_id in self.is_grad_computed[i][partition_id]:
if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None
def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = _flatten_dense_tensors(tensors)
def print_func():
logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
self.sequential_execution(print_func, message)
def get_grads_to_reduce(self, i, partition_id):
def get_reducable_portion(key):
grad = self.param_dict[key].grad
total_elements = grad.numel()
start = self.grad_start_offset[i][partition_id][key]
num_elements = min(
total_elements - start,
self.partition_size[i] -
self.grad_partition_insertion_offset[i][partition_id][key])
if not pg_correctness_test:
if num_elements == total_elements:
return grad
else:
return grad.contiguous().view(-1).narrow(0,
int(start),
int(num_elements))
else:
if num_elements == total_elements:
return grad.clone()
else:
return grad.clone().contiguous().view(-1).narrow(
0,
int(start),
int(num_elements))
grads_to_reduce = []
for key in self.is_grad_computed[i][partition_id]:
grad = get_reducable_portion(key)
grads_to_reduce.append(grad)
return grads_to_reduce
def sequential_execution(self, function, message, group=None):
if group is None:
group = self.dp_process_group
if dist.get_rank(group=group) == 0:
logger.info(message)
for id in range(dist.get_world_size(group=group)):
if id == dist.get_rank(group=group):
function()
dist.barrier(group=group)
def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
param.grad = torch.zero_like(param)
######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
rank = None
tensor = flatten(bucket)
tensor_to_allreduce = tensor
if pg_correctness_test:
allreduce_always_fp32 = True
if allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = _get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
tensor.copy_(tensor_to_allreduce)
return tensor
#if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
if self.overlap_comm:
torch.cuda.synchronize()
if self.previous_reduced_grads is not None:
# previous_reduced_grads has the previous reduced grads,
# now it is safe to clear.
for param in self.previous_reduced_grads:
param.grad = None
self.previous_reduced_grads = None
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self,
bucket,
numel_per_bucket=500000000,
rank=None,
log=None):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
#allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self,
rank,
grads,
elements_per_buffer=500000000,
log=None):
split_buckets = split_half_float_double(grads)
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket,
numel_per_bucket=elements_per_buffer,
rank=rank,
log=log)
#############################################################################
#############################################################################
#############################################################################
#views the tensor as multiple partitions and returns
#those partitions
def get_data_parallel_partitions(self, tensor):
partitions = []
dp = dist.get_world_size(group=self.dp_process_group)
dp_id = dist.get_rank(group=self.dp_process_group)
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if (current_index >= start_index and current_index < end_index):
params_in_partition.append(tensor)
elif start_index > current_index and start_index < (current_index +
tensor_size):
params_in_partition.append(tensor)
assert (first_offset==0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _model_parallel_all_reduce(self, tensor, op):
""" Perform all reduce within model parallel group, if any.
"""
if self.model_parallel_group is None:
torch.distributed.all_reduce(tensor=tensor, op=op)
else:
torch.distributed.all_reduce(tensor=tensor,
op=op,
group=self.model_parallel_group)
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
#if dist.get_rank() == 0:
# logger.info(f"Total Norm begining {total_norm}")
for g, p in zip(gradients, params):
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
#creates a flat fused tensor from the tensor list starting at the first_offset
#in the first tensor of the list. If there are not enough elements in the tensor
#list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
partition_size,
dtype,
device,
return_tensor_list=False):
flat_tensor_list = []
current_size = 0
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
continue
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
#we dont need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
#we need a narrow view of the tensor based on the tensor offset and number of elements that
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)))
else:
flat_tensor_list.append(tensor)
current_size = current_size + num_elements
#this means its the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
dtype=dtype,
device=device))
if return_tensor_list:
return flat_tensor_list
return _flatten_dense_tensors(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None
def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
def step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = -1
if self.cpu_offload:
torch.cuda.current_stream().wait_stream(self.migration_stream)
see_memory_usage(f"In step before checking overflow")
# First compute norm for all group so we know if there is overflow
self.check_overflow()
timers = self.timers
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
see_memory_usage('After overflow before clearing gradients')
self.zero_grad()
if self.cpu_offload:
self.reset_cpu_buffers()
else:
self.averaged_gradients = {}
see_memory_usage('After overflow after clearing gradients')
logger.info(
"[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
timers('optimizer_gradients').start()
timers('optimizer_gradients').stop()
timers('optimizer_step').start()
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
timers('optimizer_allgather').stop()
return
timers('optimizer_gradients').start()
norm_groups = []
single_partition_grad_groups = []
skip = False
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
if self.cpu_offload:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.params_in_partition[i]))
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
#free gradients for all the prameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_in_partition[i])
#create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i])).to(
self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = _flatten_dense_tensors(
self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
#release all the gradient since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
single_partition_grad_groups.append(single_grad_partition)
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
timers('optimizer_gradients').stop()
#torch.set_num_threads(12)
timers('optimizer_step').start()
if self.deepspeed_adam_offload:
self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
#self.optimizer.step()
#for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
# fp16_partitions[partition_id].data.copy_(fp32_partition.data)
else:
self.optimizer.step()
#get rid of the fp32 gradients. Not needed anymore
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)
timers('optimizer_step').stop()
if self.cpu_offload:
self.reset_cpu_buffers()
timers('optimizer_allgather').start()
#gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
#Sequential AllGather Best of both worlds
dp_world_size = dist.get_world_size(group=self.dp_process_group)
num_shards = max(
1,
partitioned_params[partition_id].numel() * dp_world_size //
self.allgather_bucket_size)
shard_size = partitioned_params[partition_id].numel() // num_shards
num_elements = shard_size
assert shard_size * num_shards <= partitioned_params[partition_id].numel()
for shard_id in range(num_shards):
if shard_id == (num_shards - 1):
num_elements = partitioned_params[partition_id].numel(
) - shard_id * shard_size
shard_list = []
for dp_id in range(dp_world_size):
curr_shard = partitioned_params[dp_id].narrow(
0,
shard_id * shard_size,
num_elements).detach()
shard_list.append(curr_shard)
dist.all_gather(shard_list,
shard_list[partition_id],
group=self.dp_process_group)
timers('optimizer_allgather').stop()
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
timers.log(
names=['optimizer_gradients',
'optimizer_step',
'optimizer_allgather'])
see_memory_usage('After zero_optimizer step')
return
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.fp16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
def has_overflow(self, partition_gradients=True):
if partition_gradients:
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial(
)
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
else:
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
overflow_gpu = torch.cuda.ByteTensor([overflow])
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu,
op=torch.distributed.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
if self.cpu_offload:
torch.cuda.current_stream().wait_stream(self.migration_stream)
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(int(self.reduce_bucket_size * 4.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(int(self.reduce_bucket_size * 4.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains a single flattened tensor.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_length = group.numel() - self.groups_padding[i]
groups_without_padding.append(group[:lean_length])
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_optimizer_state = self._get_state_without_padding(
self.optimizer.state[p],
self.groups_padding[i])
optimizer_groups_state.append(lean_optimizer_state)
return optimizer_groups_state
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
state_dict['partition_count'] = self.partition_count
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
self.single_partition_of_fp32_groups)
state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
# if self.cpu_offload:
# state_dict_tmp = async_copy_to(state_dict,
# 'cpu',
# torch.cuda.current_stream())
# state_dict = state_dict_tmp
return state_dict
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
partition_id = dist.get_rank(group=self.dp_process_group)
merged_single_partition_of_fp32_groups = []
for i in range(len(self.single_partition_of_fp32_groups)):
merged_partitions = [
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
]
flat_merged_partitions = flatten_dense_tensors_aligned(
merged_partitions,
dist.get_world_size(group=self.dp_process_group))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp32_partition.data.copy_(fp16_partitions[partition_id].data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
return dp_partitions[partition_id]
else:
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return all_partition_states[0]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [
sd['base_optimizer_state'][i] for sd in all_state_dict
]
for key in all_partition_group_states[0].keys():
all_partition_states = [
all_states[key] for all_states in all_partition_group_states
]
partition_states[key] = self._partition_base_optimizer_state(
key,
all_partition_states)
base_optimizer_group_states.append(partition_states)
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
self.optimizer.state[p][key].data.copy_(saved.data)
else:
self.optimizer.state[p][key] = saved
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading ZeRO checkpoint
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
"""
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict_list[0]['loss_scaler']
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
self.overflow = state_dict_list[0]['overflow']
if load_optimizer_states:
self._restore_base_optimizer_state(state_dict_list)
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 1 if changing DP degree and option 2 otherwise.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
if load_from_fp32_weights:
self._restore_from_fp32_weights(state_dict_list)
else:
self._restore_from_fp16_weights()
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
logger.info(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
import torch
import torch.distributed as dist
import apex
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
......@@ -20,3 +21,17 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
if rank in ranks:
my_group = group
return my_group
ZERO_SUPPORTED_OPTIMIZERS = [
torch.optim.Adam,
apex.optimizers.FusedAdam,
DeepSpeedCPUAdam
]
def is_zero_supported_optimizer(optimizer):
print(
f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
)
return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
......@@ -162,6 +162,14 @@ Please see the [core API doc](https://deepspeed.readthedocs.io/) for more detail
With DeepSpeed, the user can choose to use a high performance implementation of ADAM from
NVIDIA, or any training optimizer that extends torch's `torch.optim.Optimizer` class.
### CPU-Adam: High-Performance vectorized implementation of Adam
We introduce an efficient implementation of Adam optimizer on CPU that improves the parameter-update
performance by nearly an order of magnitude. We use the AVX SIMD instructions on Intel-x86 architecture
for the CPU-Adam implementation. We support both AVX-512 and AVX-2 instruction sets. DeepSpeed uses
AVX-2 by defualt which can be switched to AVX-512 by setting the build flag, `DS_BUILD_AVX512` to 1 when
installing DeepSpeed. Using AVX-512, we observe 5.1x to 6.5x speedups considering the model-size between
1 to 10 billion parameters with respect to torch-adam.
### Memory bandwidth optimized FP16 Optimizer
Mixed precision training is handled by the DeepSpeed FP16 Optimizer. This optimizer not
only handles FP16 training but is also highly efficient. The performance of weight update
......
......@@ -167,6 +167,7 @@ overview](/features/) for descriptions and usage.
* Automatic loss scaling with mixed precision
* [Training Optimizers](/features/#training-optimizers)
* Fused Adam optimizer and arbitrary `torch.optim.Optimizer`
* CPU-Adam: High-Performance vectorized Adam
* Memory bandwidth optimized FP16 Optimizer
* Large Batch Training with LAMB Optimizer
* Memory efficient Training with ZeRO Optimizer
......
......@@ -164,10 +164,10 @@ if [ ! -f $hostfile ]; then
local_only=1
fi
#if [ "$skip_requirements" == "0" ]; then
# # Ensure dependencies are installed locally
# $PIP_SUDO $PIP_INSTALL -r requirements.txt
#fi
if [ "$skip_requirements" == "0" ]; then
# Ensure dependencies are installed locally
$PIP_SUDO $PIP_INSTALL -r requirements/requirements.txt
fi
# Build wheels
if [ "$third_party_install" == "1" ]; then
......@@ -220,10 +220,10 @@ else
tmp_wheel_path="/tmp/deepspeed_wheels"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; else mkdir -pv $tmp_wheel_path; fi"
#pdcp -w $hosts requirements/*.txt ${tmp_wheel_path}/
#if [ "$skip_requirements" == "0" ]; then
# pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
#fi
pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/
if [ "$skip_requirements" == "0" ]; then
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
fi
if [ "$third_party_install" == "1" ]; then
pdsh -w $hosts "$PIP_SUDO pip uninstall -y apex"
pdcp -w $hosts third_party/apex/dist/apex*.whl $tmp_wheel_path/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment