Unverified commit 41db1c2f authored by Jeff Rasley, committed by GitHub
parent 79093d74
#include "cpu_adam.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
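// ROUND_DOWN assumes 'step' is a power of two: it clears the low-order bits of 'size',
// rounding it down to the nearest multiple of 'step' so the SIMD loops operate on aligned chunks.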
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
// C++ interface
void Adam_Optimizer::Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
float step_size = -1 * _alpha / bias_correction1;
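// The loops below implement the Adam update with bias correction folded into the scalars above:
//   m = betta1*m + (1-betta1)*g;   v = betta2*v + (1-betta2)*g^2;
//   p += step_size * m / (sqrt(v) * bias_correction2 + eps),  where step_size = -alpha/(1-betta1^t)
//   and bias_correction2 = 1/sqrt(1-betta2^t).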
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay4;
if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
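// Process parameters in TILE-sized chunks. When dev_params is given, each chunk's updated values
// are also written into a pinned host buffer and copied to the GPU asynchronously while the next
// chunk is computed (_buf_index ping-pongs between the two buffers).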
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
grad_4.data = SIMD_LOAD(grads + i);
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
AVX_Data param_4;
param_4.data = SIMD_LOAD(_params + i);
if (_weight_decay > 0)
grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
SIMD_STORE(_params + i, param_4.data);
if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
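// Scalar tail: handles the elements not covered by the SIMD loop (or the whole range when AVX is unavailable).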
if (_param_size > rounded_size) {
#pragma omp parallel for
for (size_t k = rounded_size; k < _param_size; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0) grad = param * _weight_decay + grad;
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * bias_correction2 + _eps;
grad = momentum / grad;
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + rounded_size,
(_param_size - rounded_size),
Context::Instance().GetCurrentStream());
}
}
}
void Adam_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
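// Same update as Step(), but unrolled to process 4 SIMD vectors per iteration; any remainder
// is delegated to Step() at the end.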
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
AVX_Data grad_4[4];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
AVX_Data momentum_4[4];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
AVX_Data variance_4[4];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
AVX_Data param_4[4];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
if (_weight_decay > 0) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int create_adam_optimizer(int optimizer_id,
float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
{
auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay);
s_optimizers[optimizer_id] = opt;
#if defined(__AVX512__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX512 arithmetic capability." << std::endl;
#else
#if defined(__AVX256__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX2 arithmetic capability." << std::endl;
#else
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with scalar arithmetic capability." << std::endl;
#endif
#endif
return 0;
}
void Adam_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
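// Same update as Step_4(), unrolled to 8 SIMD vectors per iteration; the remainder falls through to Step_4().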
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
AVX_Data grad_4[8];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2));
grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5);
grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6);
grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7);
AVX_Data momentum_4[8];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2));
momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5);
momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6);
momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7);
AVX_Data variance_4[8];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2));
variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5);
variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6);
variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7);
AVX_Data param_4[8];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2));
param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5);
param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);
if (_weight_decay > 0) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data);
grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data);
grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data);
grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data);
momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data);
momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data);
momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data);
momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data);
momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data);
momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data);
momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data);
variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data);
variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data);
variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data);
variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data);
variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data);
variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[4].data = SIMD_SQRT(variance_4[4].data);
grad_4[5].data = SIMD_SQRT(variance_4[5].data);
grad_4[6].data = SIMD_SQRT(variance_4[6].data);
grad_4[7].data = SIMD_SQRT(variance_4[7].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[4].data = SIMD_FMA(grad_4[4].data, bias2_sqrt.data, eps_4.data);
grad_4[5].data = SIMD_FMA(grad_4[5].data, bias2_sqrt.data, eps_4.data);
grad_4[6].data = SIMD_FMA(grad_4[6].data, bias2_sqrt.data, eps_4.data);
grad_4[7].data = SIMD_FMA(grad_4[7].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_DIV(momentum_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_DIV(momentum_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data);
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(grad_4[4].data, step_size_4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(grad_4[5].data, step_size_4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(grad_4[6].data, step_size_4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(grad_4[7].data, step_size_4.data, param_4[7].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 2), param_4[4].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 7, param_4[7].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2),
param_4[4].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 2), momentum_4[4].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 5, momentum_4[5].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 6, momentum_4[6].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 7, momentum_4[7].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 2), variance_4[4].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 5, variance_4[5].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 6, variance_4[6].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int ds_adam_step(int optimizer_id,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq)
{
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep();
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
return 0;
}
int ds_adam_step_plus_copy(int optimizer_id,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
__half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep();
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
return 0;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
}
#include "custom_cuda_layers.h"
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
const float4* input_cast = reinterpret_cast<const float4*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < size) {
float4 data = input_cast[id];
float2 cast_data;
__half* output_h = reinterpret_cast<__half*>(&cast_data);
output_h[0] = (__half)data.x;
output_h[1] = (__half)data.y;
output_h[2] = (__half)data.z;
output_h[3] = (__half)data.w;
output_cast[id] = cast_data;
}
}
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 512;
size /= 4;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);
param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
#pragma once
#include <cpuid.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <x86intrin.h>
#include <cassert>
#include "context.h"
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define TILE (1024 * 1024 * 1024)
#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16
#else
#if defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8
#endif
#endif
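// If neither __AVX512__ nor __AVX256__ is defined, SIMD_* and SIMD_WIDTH stay undefined and the
// optimizer falls back to the plain scalar loop in cpu_adam.cpp (rounded_size stays 0, so the
// tail loop covers the whole parameter range).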
class Adam_Optimizer {
public:
Adam_Optimizer(float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_buf_index(false)
{
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
}
~Adam_Optimizer()
{
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
}
void Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr);
void Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr);
void Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params = nullptr);
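// Maintains the running powers betta1^t and betta2^t used for Adam's bias correction.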
inline void IncrementStep()
{
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
private:
#if defined(__AVX512__) or defined(__AVX256__)
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#else
__m256 data;
#endif
// float data_f[16];
};
#endif
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
float* _doubled_buffer[2];
bool _buf_index;
};
@@ -264,3 +264,5 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
                                        int rows,
                                        int cols,
                                        cudaStream_t stream);
+
+void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
@@ -4,16 +4,18 @@ Copyright 2020 The Microsoft DeepSpeed Team
 import sys
 import types
-from deepspeed.runtime.engine import DeepSpeedEngine
-from deepspeed.runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
-from deepspeed.runtime.lr_schedules import add_tuning_arguments
-from deepspeed.runtime.config import DeepSpeedConfig
-from deepspeed.runtime.activation_checkpointing import checkpointing
-from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
-from deepspeed.utils import logger
+from . import ops
+
+from .runtime.engine import DeepSpeedEngine
+from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
+from .runtime.lr_schedules import add_tuning_arguments
+from .runtime.config import DeepSpeedConfig
+from .runtime.activation_checkpointing import checkpointing
+from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
+from .utils import logger
 try:
-    from deepspeed.git_version_info import version, git_hash, git_branch
+    from .git_version_info import version, git_hash, git_branch
 except ImportError:
     version = "0.0.0+unknown"
     git_hash = None
...
from ..git_version_info import installed_ops as __installed_ops__
from . import lamb
from . import transformer
if __installed_ops__['sparse-attn']:
from . import sparse_attention
if __installed_ops__['cpu-adam']:
from . import adam
from .cpu_adam import DeepSpeedCPUAdam
import math
import torch
import importlib
ds_opt_adam = None
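# The compiled cpu_adam extension ('deepspeed.ops.adam.cpu_adam_op') is imported lazily in
# __init__ and shared by all DeepSpeedCPUAdam instances through the module-level ds_opt_adam handle.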
class DeepSpeedCPUAdam(torch.optim.Optimizer):
optimizer_id = 0
def __init__(self,
model_params,
lr=1e-3,
betas=(0.9,
0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False):
default_args = dict(lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad)
super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)
self.opt_id = DeepSpeedCPUAdam.optimizer_id
DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
global ds_opt_adam
ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)
def __setstate__(self, state):
super(DeepSpeedCPUAdam, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None, fp16_param_groups=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group_id, group in enumerate(self.param_groups):
for param_id, p in enumerate(group['params']):
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
print(f'group {group_id} param {param_id} = {p.numel()}')
state['step'] = 0
# gradient momentums
state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
state['step'] += 1
if fp16_param_groups is not None:
p_fp16 = fp16_param_groups[group_id][param_id]
ds_opt_adam.adam_update_copy(self.opt_id,
p.data,
grad,
exp_avg,
exp_avg_sq,
p_fp16)
else:
ds_opt_adam.adam_update(self.opt_id,
p.data,
grad,
exp_avg,
exp_avg_sq)
return loss
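# Illustrative usage sketch (not part of this commit): builds a small CPU-resident model and runs one
# DeepSpeedCPUAdam step. The model and tensor shapes below are made up; the sketch assumes the
# cpu_adam extension has been built and is importable as deepspeed.ops.adam.cpu_adam_op.
#
#   import torch
#   from deepspeed.ops.adam import DeepSpeedCPUAdam
#
#   model = torch.nn.Linear(16, 4)                 # parameters live on the CPU
#   optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, weight_decay=0)
#   loss = model(torch.randn(8, 16)).sum()
#   loss.backward()
#   optimizer.step()                               # runs the vectorized C++ Adam kernel above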
import torch
from torch.autograd import Variable
import collections
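# Helpers for recursively moving (possibly nested) tensors, mappings, and sequences from host to a
# CUDA device, optionally recording the copy on a given stream so later kernels can synchronize on it.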
def async_migrate_to(obj, dev, main_stream=None):
if torch.is_tensor(obj):
obj = Variable(obj)
if isinstance(obj, Variable):
v = obj.cuda(dev, non_blocking=True)
if main_stream is not None:
v.data.record_stream(main_stream)
return v
elif isinstance(obj, collections.Mapping):
return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
elif isinstance(obj, collections.Sequence):
return [async_copy_to(o, dev, main_stream) for o in obj]
else:
return obj
def async_copy_to(obj, dev, main_stream=None):
if torch.is_tensor(obj):
obj = Variable(obj)
if isinstance(obj, Variable):
target = torch.empty_like(obj, device=dev).copy_(obj)
if main_stream is not None:
target.data.record_stream(main_stream)
return target
elif isinstance(obj, collections.Mapping):
return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
elif isinstance(obj, collections.Sequence):
return [async_copy_to(o, dev, main_stream) for o in obj]
else:
return obj
@@ -10,14 +10,21 @@ from deepspeed.runtime.constants import *
 from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
 from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
+from deepspeed.runtime.zero.constants import *
 from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
 from deepspeed.utils import logger

 TENSOR_CORE_ALIGN_SIZE = 8
-ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
 ADAM_OPTIMIZER = 'adam'
 LAMB_OPTIMIZER = 'lamb'
-DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER]
+ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
+DEEPSPEED_ADAM = 'deepspeed_adam'
+DEEPSPEED_OPTIMIZERS = [
+    ADAM_OPTIMIZER,
+    LAMB_OPTIMIZER,
+    ONEBIT_ADAM_OPTIMIZER,
+    DEEPSPEED_ADAM
+]

 def get_amp_enabled(param_dict):
@@ -111,22 +118,9 @@ def get_zero_optimization(param_dict):
 def get_zero_reduce_scatter(param_dict):
-    return get_scalar_param(param_dict, ZERO_REDUCE_SCATTER, ZERO_REDUCE_SCATTER_DEFAULT)
-
-def get_zero_max_elements_per_comm(param_dict):
-    return get_scalar_param(param_dict,
-                            ZERO_MAX_ELEMENTS_PER_COMM,
-                            ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT)
-
-def get_allgather_size(param_dict):
-    return get_scalar_param(param_dict,
-                            ALLGATHER_SIZE,
-                            ALLGATHER_SIZE_DEFAULT) if get_scalar_param(
-                                param_dict,
-                                ALLGATHER_SIZE,
-                                ALLGATHER_SIZE_DEFAULT) > 0 else ALLGATHER_SIZE_DEFAULT
+    return get_scalar_param(param_dict,
+                            ZERO_OPTIMIZATION_REDUCE_SCATTER,
+                            ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)

 def get_allreduce_always_fp32(param_dict):
@@ -493,8 +487,6 @@ class DeepSpeedConfig(object):
         self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
         self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
-        self.allgather_size = get_allgather_size(param_dict)

         self.zero_config = DeepSpeedZeroConfig(param_dict)
         self.zero_optimization_stage = self.zero_config.stage
         self.zero_enabled = self.zero_optimization_stage > 0
@@ -628,15 +620,18 @@ class DeepSpeedConfig(object):
                             ':'))))

     def _do_error_check(self):
-        if self.zero_enabled:
-            assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
-            assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
         assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
-        assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
+        assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
             GRADIENT_ACCUMULATION_STEPS)
+        if self.zero_enabled:
+            assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
+            assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
+            if self.zero_config.cpu_offload is True:
+                assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
+                #assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)

     def _do_warning_check(self):
         fp16_enabled = self.fp16_enabled or self.zero_enabled
...
@@ -183,35 +183,6 @@ Gradient clipping should be enabled as:
 GRADIENT_CLIPPING = 'gradient_clipping'
 GRADIENT_CLIPPING_DEFAULT = 0.
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": [0|1|2],
"zero_all_gather_size": 200
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DEFAULT = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_REDUCE_SCATTER = "zero_reduce_scatter"
ZERO_REDUCE_SCATTER_DEFAULT = True
ZERO_MAX_ELEMENTS_PER_COMM = "zero_max_elements_per_comm"
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT = 5e8
ALLGATHER_SIZE = 'allgather_size'
ALLGATHER_SIZE_DEFAULT = 500000000
 #########################################
 # FP32 AllReduce
 #########################################
...
@@ -7,6 +7,7 @@ import torch
 import warnings
 import torch.distributed as dist
+import apex
 from apex import amp
 from torch.nn.modules import Module
 from torch.distributed.distributed_c10d import _get_global_rank
@@ -14,20 +15,20 @@ from tensorboardX import SummaryWriter
 from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
+from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, \
-    ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS
+    ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_ADAM, DEEPSPEED_OPTIMIZERS
 from deepspeed.runtime.dataloader import DeepSpeedDataLoader
 from deepspeed.runtime.constants import \
     ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
-    TORCH_DISTRIBUTED_DEFAULT_PORT, \
-    ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
+    TORCH_DISTRIBUTED_DEFAULT_PORT
+from deepspeed.runtime.zero.constants import \
+    ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.runtime.csr_tensor import CSRTensor
 import deepspeed.runtime.lr_schedules as lr_schedules
 from deepspeed.utils import logger
 from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
@@ -105,7 +106,6 @@ class DeepSpeedEngine(Module):
                  collate_fn=None,
                  config_params=None):
         super(DeepSpeedEngine, self).__init__()
         self.client_optimizer = optimizer
         self.client_model_parameters = model_parameters
         self.client_lr_scheduler = lr_scheduler
@@ -266,7 +266,7 @@ class DeepSpeedEngine(Module):
         return self._config.train_micro_batch_size_per_gpu

     def optimizer_name(self):
-        return self._config.optimizer_name
+        return self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name

     def optimizer_params(self):
         return self._config.optimizer_params
@@ -292,6 +292,9 @@ class DeepSpeedEngine(Module):
     def zero_overlap_comm(self):
         return self._config.zero_config.overlap_comm

+    def zero_cpu_offload(self):
+        return self._config.zero_config.cpu_offload
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage
@@ -310,9 +313,6 @@ class DeepSpeedEngine(Module):
     def zero_load_from_fp32_weights(self):
         return self._config.zero_config.load_from_fp32_weights

-    def allgather_size(self):
-        return self._config.allgather_size
-
     def fp16_enabled(self):
         return self._config.fp16_enabled
@@ -491,6 +491,7 @@ class DeepSpeedEngine(Module):
     # Configure optimizer
     def _configure_optimizer(self, client_optimizer, model_parameters):
         if client_optimizer is not None:
             basic_optimizer = client_optimizer
             logger.info('Using client Optimizer as basic optimizer')
@@ -504,13 +505,14 @@ class DeepSpeedEngine(Module):
         if self.zero_optimization():
             assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
-            if self.optimizer_name() != ADAM_OPTIMIZER:
+            if not is_zero_supported_optimizer(basic_optimizer):
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
                 logger.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )
             self.optimizer = self._configure_zero_optimizer(basic_optimizer)
         elif self.amp_enabled():
             assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
@@ -522,8 +524,8 @@ class DeepSpeedEngine(Module):
             self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
         else:
             self.optimizer = basic_optimizer
-        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
-        # logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
+        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))

     def _configure_basic_optimizer(self, model_parameters):
         optimizer_parameters = self.optimizer_params()
@@ -533,8 +535,14 @@ class DeepSpeedEngine(Module):
                 "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
             )
         if self.optimizer_name() == ADAM_OPTIMIZER:
-            from apex.optimizers.fused_adam import FusedAdam
-            optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+            if self.zero_cpu_offload():
+                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
+            else:
+                from apex.optimizers.fused_adam import FusedAdam
+                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+        elif self.optimizer_name() == DEEPSPEED_ADAM:
+            from deepspeed.ops.adam import DeepSpeedCPUAdam
+            optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             from deepspeed.ops.lamb import FusedLamb
             optimizer = FusedLamb(model_parameters, **optimizer_parameters)
@@ -550,7 +558,8 @@ class DeepSpeedEngine(Module):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name(
+        if isinstance(optimizer,
+                      apex.optimizers.FusedAdam) or self.optimizer_name(
         ) == ONEBIT_ADAM_OPTIMIZER:
             if self.dynamic_loss_scale():
                 logger.info('Creating fp16 optimizer with dynamic loss scale')
@@ -616,9 +625,11 @@ class DeepSpeedEngine(Module):
                 dp_process_group=self.data_parallel_group,
                 reduce_scatter=self.zero_reduce_scatter(),
                 overlap_comm=self.zero_overlap_comm(),
+                cpu_offload=self.zero_cpu_offload(),
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
-                gradient_predivide_factor=self.gradient_predivide_factor())
+                gradient_predivide_factor=self.gradient_predivide_factor(),
+                gradient_accumulation_steps=self.gradient_accumulation_steps())
         else:
             raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
@@ -724,7 +735,6 @@ class DeepSpeedEngine(Module):
         return loss

     def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
         #Zero stage 2 communicates during non gradient accumulation boundaries as well
         if self.zero_optimization_partition_gradients():
             self.optimizer.overlapping_partition_gradients_reduce_epilogue()
@@ -780,6 +790,8 @@ class DeepSpeedEngine(Module):
             self.timers('backward_inner').start()

         if self.zero_optimization():
+            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
+            )
             self.optimizer.backward(loss)
         elif self.amp_enabled():
             # AMP requires delaying unscale when inside gradient accumulation boundaries
@@ -854,7 +866,6 @@ class DeepSpeedEngine(Module):
             master_params = amp.master_params(self.optimizer)
             torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                            max_norm=self.gradient_clipping())
         self.optimizer.step()
         #zero grad in basic optimizer could be unreliable and may not exhibit
@@ -957,6 +968,9 @@ class DeepSpeedEngine(Module):
     def get_lr(self):
         return self._get_optimizer_param('lr')

+    def get_type(self):
+        return self._get_optimizer_param('type')
+
     def get_mom(self):
         return self._get_optimizer_param('betas')
...
@@ -5,79 +5,7 @@ Licensed under the MIT license.
 from deepspeed.runtime.config_utils import get_scalar_param
 from deepspeed.utils import logger
+from deepspeed.runtime.zero.constants import *
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER:
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
}
 class DeepSpeedZeroConfig(object):
@@ -92,6 +20,7 @@ class DeepSpeedZeroConfig(object):
         self.allgather_bucket_size = None
         self.overlap_comm = None
         self.load_from_fp32_weights = None
+        self.cpu_offload = None

         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -156,7 +85,12 @@ class DeepSpeedZeroConfig(object):
                 zero_config_dict,
                 ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
                 ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)

             self.load_from_fp32_weights = get_scalar_param(
                 zero_config_dict,
                 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
                 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
+
+            self.cpu_offload = get_scalar_param(zero_config_dict,
+                                                ZERO_OPTIMIZATION_CPU_OFFLOAD,
+                                                ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
"cpu_offload": [true|false]
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = False
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
}
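# Illustrative sketch (not part of this patch): how a user-provided config dict maps onto the
# keys above. `get_scalar_param` here is a minimal stand-in for the helper referenced in
# DeepSpeedZeroConfig (the real helper lives elsewhere in DeepSpeed and may differ in detail);
# the `_user_zero_config` dict is hypothetical.
def get_scalar_param(param_dict, param_name, param_default):
    # Fall back to the module-level default when the user omits the key.
    return param_dict.get(param_name, param_default)

_user_zero_config = {"stage": 2, "cpu_offload": True}  # hypothetical "zero_optimization" section
assert get_scalar_param(_user_zero_config,
                        ZERO_OPTIMIZATION_CPU_OFFLOAD,
                        ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT) is True
assert get_scalar_param(_user_zero_config,
                        ZERO_OPTIMIZATION_OVERLAP_COMM,
                        ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT) is False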
...@@ -793,8 +793,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
...
...@@ -10,11 +10,14 @@ import math
from torch._six import inf
from torch.autograd import Variable
import collections
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
#Toggle this to true to enable correctness test
#with gradient partitioning and without
pg_correctness_test = False
...@@ -122,15 +125,18 @@ class FP16_DeepSpeedZeroOptimizer(object):
dp_process_group=None,
reduce_scatter=True,
overlap_comm=False,
cpu_offload=False,
mpu=None,
clip_grad=0.0,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1):
if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {allgather_bucket_size}")
logger.info(f"CPU Offload: {cpu_offload}")
# The fused optimizer does all the work. We need this layer for two reasons:
# 1. maintain same user API from apex.fp16_utils
# 2. keep common stuff here in case we need to add a new fused optimizer later
...@@ -150,6 +156,13 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.overlap_comm = overlap_comm
self.cpu_offload = cpu_offload
self.deepspeed_adam_offload = (cpu_offload
and type(init_optimizer) == DeepSpeedCPUAdam)
self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
self.dp_process_group = dp_process_group
self.partition_count = dist.get_world_size(group=self.dp_process_group)
...@@ -166,6 +179,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.allreduce_always_fp32 = allreduce_always_fp32
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
...@@ -253,8 +268,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
# a partition of the fp32 master weights that will be updated by this process
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i][partition_id].to(
self.device).clone().float().detach())
# modify optimizer to have flat master weight
self.single_partition_of_fp32_groups[
...@@ -275,6 +290,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream()
self.cpu_computation_stream = torch.cuda.Stream()
self.migration_stream = torch.cuda.Stream()
self.callback_queued = False
self.param_dict = {}
...@@ -282,7 +299,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
#map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
# CPU-Offload requires contiguous gradients
self.contiguous_gradients = contiguous_gradients or cpu_offload
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
...@@ -293,6 +311,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
#simplified param id
self.param_id = {}
largest_param_numel = 0
count = 0
for i, params_group in enumerate(self.fp16_groups):
for param in params_group:
...@@ -300,6 +319,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
if param.numel() > largest_param_numel:
largest_param_numel = param.numel()
count = count + 1
for param_group in self.params_in_partition:
...@@ -310,6 +331,24 @@ class FP16_DeepSpeedZeroOptimizer(object):
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = False
if self.cpu_offload:
self.accumulated_grads_in_cpu = {}
self.norm_for_param_grads = {}
self.local_overflow = False
self.grad_position = {}
self.temp_grad_buffer_for_cpu_offload = torch.zeros(
largest_param_numel,
device=self.device).half().pin_memory()
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
largest_param_numel,
device=torch.cuda.current_device()).half()
for i, params_group in enumerate(self.fp16_groups):
self.get_grad_position(i,
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i])
#mapping from parameter to partition that it belongs to
self.param_to_partition_ids = {}
...@@ -382,11 +421,14 @@ class FP16_DeepSpeedZeroOptimizer(object):
single_grad_partition = torch.zeros(
int(self.partition_size[i]),
dtype=self.single_partition_of_fp32_groups[i].dtype,
device=self.device)
self.single_partition_of_fp32_groups[
i].grad = single_grad_partition.pin_memory(
) if self.cpu_offload else single_grad_partition
self.optimizer.step()
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None
...@@ -444,7 +486,9 @@ class FP16_DeepSpeedZeroOptimizer(object):
if self.overlap_comm:
torch.cuda.synchronize()
if self.cpu_offload is False:
for i, _ in enumerate(self.fp16_groups):
if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
...@@ -454,9 +498,6 @@ class FP16_DeepSpeedZeroOptimizer(object):
device=torch.cuda.current_device(),
return_tensor_list=True)
else:
#When gradient accumulation is greater than 1
#This code path will be triggered and will add
#to the accumulated averaged gradients
avg_new = self.get_flat_partition(self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
...@@ -628,9 +669,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.gradient_predivide_factor != dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
else:
tensor_to_allreduce.div_(dp_world_size)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
...@@ -705,7 +745,185 @@ class FP16_DeepSpeedZeroOptimizer(object):
for handle in async_handles:
handle.wait()
##############################################################################
############################ CPU Offload Methods ############################
##############################################################################
def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):
current_offset = 0
for i, tensor in enumerate(tensor_list):
param_id = self.get_param_id(tensor)
param_start_offset = 0
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
param_start_offset = first_offset
#we don't need all elements of the tensor
if num_elements > (partition_size - current_offset):
num_elements = partition_size - current_offset
self.grad_position[param_id] = [
int(group_id),
int(param_start_offset),
int(current_offset),
int(num_elements)
]
current_offset += num_elements
def update_overflow_tracker_for_param_grad(self, param):
if param.grad is not None and self._has_inf_or_nan(param.grad.data):
self.local_overflow = True
def async_accumulate_grad_in_cpu(self, param):
param_id = self.get_param_id(param)
#copy to a pre-existing buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_cpu_offload.view(-1).narrow(
0,
0,
param.numel())
dest_buffer.copy_(param.grad.view(-1), non_blocking=True)
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
dtype=param.dtype,
device=self.device).pin_memory()
self.accumulated_grads_in_cpu[param_id].add_(dest_buffer)
def async_accumulate_grad_in_cpu_via_gpu(self, param):
param_id = self.get_param_id(param)
#copy to a pre-existing buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
0,
0,
param.numel())
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
dtype=param.dtype,
device=self.device).pin_memory()
if self.micro_step_id > 0:
dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
non_blocking=True)
param.grad.data.view(-1).add_(dest_buffer)
#at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1),
non_blocking=True)
def set_norm_for_param_grad(self, param):
param_id = self.get_param_id(param)
accumulated_grad = self.accumulated_grads_in_cpu[
param_id] if self.gradient_accumulation_steps > 1 else param.grad
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)
def set_norm_for_param_grad_in_gpu(self, param):
param_id = self.get_param_id(param)
accumulated_grad = param.grad
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)
def async_inplace_copy_grad_to_fp32_buffer(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
0,
dest_offset,
num_elements)
if self.gradient_accumulation_steps > 1:
src_tensor = self.accumulated_grads_in_cpu[param_id].view(-1).narrow(
0,
source_offset,
num_elements)
else:
src_tensor = param.grad.view(-1).narrow(0,
source_offset,
num_elements).float()
dest_tensor.copy_(src_tensor, non_blocking=True)
def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
0,
dest_offset,
num_elements)
src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float()
dest_tensor.copy_(src_tensor, non_blocking=True)
param.grad = None
def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
norm_type = 2.0
for p in params:
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_id = self.get_param_id(p)
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
############################################################################################
def copy_grads_in_partition(self, param):
if self.cpu_offload:
if self.gradient_accumulation_steps > 1:
self.async_accumulate_grad_in_cpu_via_gpu(param)
if self.is_gradient_accumulation_boundary:
self.set_norm_for_param_grad_in_gpu(param)
self.update_overflow_tracker_for_param_grad(param)
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
return
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
if self.grads_in_partition is None:
self.grads_in_partition_offset = 0
total_size = 0
...@@ -720,11 +938,13 @@ class FP16_DeepSpeedZeroOptimizer(object):
see_memory_usage(f"after copying {total_size} gradients into partition")
#The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
new_grad_tensor = self.grads_in_partition.view(-1).narrow(
0,
self.grads_in_partition_offset,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
#print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
self.grads_in_partition_offset += param.numel()
def reduce_ipg_grads(self):
...@@ -1104,10 +1324,19 @@ class FP16_DeepSpeedZeroOptimizer(object):
for p in param_list:
p.grad = None
def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
def step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = -1
if self.cpu_offload:
torch.cuda.current_stream().wait_stream(self.migration_stream)
see_memory_usage(f"In step before checking overflow") see_memory_usage(f"In step before checking overflow")
# First compute norm for all group so we know if there is overflow # First compute norm for all group so we know if there is overflow
...@@ -1120,8 +1349,10 @@ class FP16_DeepSpeedZeroOptimizer(object): ...@@ -1120,8 +1349,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
if self.overflow: if self.overflow:
see_memory_usage('After overflow before clearing gradients') see_memory_usage('After overflow before clearing gradients')
self.zero_grad() self.zero_grad()
for key in self.averaged_gradients: if self.cpu_offload:
self.averaged_gradients[key] = None self.reset_cpu_buffers()
else:
self.averaged_gradients = {}
see_memory_usage('After overflow after clearing gradients')
...@@ -1130,18 +1361,26 @@ class FP16_DeepSpeedZeroOptimizer(object):
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
timers('optimizer_gradients').start()
timers('optimizer_gradients').stop()
timers('optimizer_step').start()
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
timers('optimizer_allgather').stop()
return
timers('optimizer_gradients').start()
norm_groups = []
single_partition_grad_groups = []
skip = False
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
if self.cpu_offload:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.params_in_partition[i]))
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
...@@ -1172,17 +1411,31 @@ class FP16_DeepSpeedZeroOptimizer(object):
single_partition_grad_groups.append(single_grad_partition)
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
timers('optimizer_gradients').stop()
#torch.set_num_threads(12)
timers('optimizer_step').start()
if self.deepspeed_adam_offload:
self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
#self.optimizer.step()
#for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
# fp16_partitions[partition_id].data.copy_(fp32_partition.data)
else:
self.optimizer.step()
#get rid of the fp32 gradients. Not needed anymore
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)
timers('optimizer_step').stop()
if self.cpu_offload:
self.reset_cpu_buffers()
timers('optimizer_allgather').start()
#gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
...@@ -1225,6 +1478,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
timers.log(
names=['optimizer_gradients',
'optimizer_step',
'optimizer_allgather'])
see_memory_usage('After zero_optimizer step')
return
...@@ -1270,7 +1527,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
def has_overflow(self, partition_gradients=True):
if partition_gradients:
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial(
)
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
...@@ -1323,16 +1581,20 @@ class FP16_DeepSpeedZeroOptimizer(object):
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
if self.cpu_offload:
torch.cuda.current_stream().wait_stream(self.migration_stream)
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(int(self.reduce_bucket_size * 4.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(int(self.reduce_bucket_size * 4.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
...@@ -1389,8 +1651,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
...@@ -1432,6 +1697,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.single_partition_of_fp32_groups)
state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
# if self.cpu_offload:
# state_dict_tmp = async_copy_to(state_dict,
# 'cpu',
# torch.cuda.current_stream())
# state_dict = state_dict_tmp
return state_dict
# Restore base optimizer fp32 weights from checkpoint by:
...@@ -1468,11 +1739,15 @@ class FP16_DeepSpeedZeroOptimizer(object):
def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
return dp_partitions[partition_id]
else:
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return all_partition_states[0]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
...@@ -1497,8 +1772,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
self.optimizer.state[p][key].data.copy_(saved.data)
else:
self.optimizer.state[p][key] = saved
def load_state_dict(self,
state_dict_list,
...
import torch
import torch.distributed as dist
import apex
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
...@@ -20,3 +21,17 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
if rank in ranks:
my_group = group
return my_group
ZERO_SUPPORTED_OPTIMIZERS = [
torch.optim.Adam,
apex.optimizers.FusedAdam,
DeepSpeedCPUAdam
]
def is_zero_supported_optimizer(optimizer):
print(
f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
)
return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
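# Illustrative usage sketch (not part of this patch). Note the design choice above:
# the membership test uses type() rather than isinstance(), so a subclass of a
# supported optimizer is intentionally not matched. The optimizer instances below
# are hypothetical examples.
_params = [torch.nn.Parameter(torch.zeros(1))]
print(is_zero_supported_optimizer(torch.optim.Adam(_params, lr=1e-3)))  # True
print(is_zero_supported_optimizer(torch.optim.SGD(_params, lr=1e-3)))   # False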
...@@ -162,6 +162,14 @@ Please see the [core API doc](https://deepspeed.readthedocs.io/) for more detail
With DeepSpeed, the user can choose to use a high performance implementation of ADAM from
NVIDIA, or any training optimizer that extends torch's `torch.optim.Optimizer` class.
### CPU-Adam: High-Performance vectorized implementation of Adam
We introduce an efficient implementation of the Adam optimizer on CPU that improves parameter-update
performance by nearly an order of magnitude. The CPU-Adam implementation uses AVX SIMD instructions on
the Intel x86 architecture and supports both the AVX-512 and AVX-2 instruction sets. DeepSpeed uses
AVX-2 by default; it can be switched to AVX-512 by setting the build flag `DS_BUILD_AVX512` to 1 when
installing DeepSpeed. Using AVX-512, we observe speedups of 5.1x to 6.5x over torch-adam for model sizes
between 1 and 10 billion parameters.
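
As a rough sketch of how CPU-Adam can be used as a drop-in replacement for `torch.optim.Adam` (the exact
constructor keyword names may vary between releases, so treat them as illustrative rather than definitive):

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Optionally build DeepSpeed with DS_BUILD_AVX512=1 (see the flag above) to enable AVX-512.
# Hypothetical CPU-resident model: CPU-Adam performs the parameter update on the host,
# which is why it pairs naturally with ZeRO-Offload ("cpu_offload": true).
model = torch.nn.Linear(1024, 1024)

# Adam-like constructor; argument names here are assumed to mirror torch.optim.Adam.
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

loss = model(torch.randn(8, 1024)).sum()
loss.backward()
optimizer.step()
```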
### Memory bandwidth optimized FP16 Optimizer
Mixed precision training is handled by the DeepSpeed FP16 Optimizer. This optimizer not
only handles FP16 training but is also highly efficient. The performance of weight update
...
...@@ -167,6 +167,7 @@ overview](/features/) for descriptions and usage.
* Automatic loss scaling with mixed precision
* [Training Optimizers](/features/#training-optimizers)
* Fused Adam optimizer and arbitrary `torch.optim.Optimizer`
* CPU-Adam: High-Performance vectorized Adam
* Memory bandwidth optimized FP16 Optimizer
* Large Batch Training with LAMB Optimizer
* Memory efficient Training with ZeRO Optimizer
...
...@@ -164,10 +164,10 @@ if [ ! -f $hostfile ]; then
local_only=1
fi
if [ "$skip_requirements" == "0" ]; then
# Ensure dependencies are installed locally
$PIP_SUDO $PIP_INSTALL -r requirements/requirements.txt
fi
# Build wheels
if [ "$third_party_install" == "1" ]; then
...@@ -220,10 +220,10 @@ else
tmp_wheel_path="/tmp/deepspeed_wheels"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; else mkdir -pv $tmp_wheel_path; fi"
pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/
if [ "$skip_requirements" == "0" ]; then
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
fi
if [ "$third_party_install" == "1" ]; then
pdsh -w $hosts "$PIP_SUDO pip uninstall -y apex"
pdcp -w $hosts third_party/apex/dist/apex*.whl $tmp_wheel_path/
...