// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#define NOMINMAX  // Windows idiosyncrasy
                  // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <cassert>
#include "simd.h"

// ds_half_precision_t is the element type used for the optional half-precision
// device-side parameter pointer. CUDA builds use the native __half type; other
// builds fall back to a 16-bit unsigned integer so signatures stay the same.
#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;
#endif

// Declares a member function named Step_<SPAN> taking host parameter, gradient
// and squared-gradient-accumulator pointers, the element count, and an optional
// device parameter pointer plus a half-precision flag.
// The expansion already ends with ';', so use it as `STEP(4)` without a
// trailing semicolon.
// NOTE(review): SPAN presumably corresponds to the `span` unroll factor of
// Step_AVX; the definitions live outside this header -- confirm there.
#define STEP(SPAN)                                             \
    void Step_##SPAN(float* _params,                           \
                     float* grads,                             \
                     float* _exp_avg_sq,                       \
                     size_t _param_size,                       \
                     ds_half_precision_t* dev_param = nullptr, \
                     bool half_precision = false);

// Host-side Adagrad optimizer state and step entry points.
//
// Holds the hyper-parameters (learning rate, epsilon, weight decay) and the
// step counter. In CUDA builds it additionally owns two pinned host staging
// buffers of TILE floats each and two streams, which Step_AVX alternates
// between when pushing updated parameters to the device.
//
// NOTE(review): the class owns raw pinned allocations but is still implicitly
// copyable; copying an instance would free the same buffers twice. Callers
// should hold it by pointer/reference -- confirm no call site copies it.
class Adagrad_Optimizer {
public:
    // alpha: learning rate, eps: denominator stabiliser, weight_decay: L2 factor
    // (applied in Step_AVX only when > 0).
    Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0)
        : _alpha(alpha),
          _eps(eps),
          _weight_decay(weight_decay),
          _betta1_t(1.0f),
          _betta2_t(1.0f),
          _step(0)  // previously left uninitialised although IncrementStep read it
    {
#if defined(__ENABLE_CUDA__)
        // Allocate the two pinned staging buffers. The pointers start as null so
        // that a failed allocation leaves a well-defined value behind instead of
        // an indeterminate pointer.
        for (int i = 0; i < 2; i++) {
            _doubled_buffer[i] = nullptr;
            cudaError_t err = cudaMallocHost((void**)(_doubled_buffer + i), TILE * sizeof(float));
            if (err != cudaSuccess) {
                _doubled_buffer[i] = nullptr;
                fprintf(stderr,
                        "Adagrad_Optimizer: cudaMallocHost failed: %s\n",
                        cudaGetErrorString(err));
            }
        }

        _streams[0] = TrainingContext::Instance().GetCurrentStream();
        _streams[1] = TrainingContext::Instance().GetNewStream();
        _buf_index = false;
#endif
    }

    ~Adagrad_Optimizer()
    {
#if defined(__ENABLE_CUDA__)
        // Release the pinned staging buffers acquired in the constructor.
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
#endif
    }

#if defined(__AVX512__) || defined(__AVX256__)
    // Vectorised step over the SIMD-aligned prefix of the parameters; the number
    // of elements it handled is written to *rounded_size.
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg_sq,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif

    // Declarations of Step_1, Step_4 and Step_8 (the macro supplies the ';').
    STEP(1)
    STEP(4)
    STEP(8)

#if defined(__ENABLE_CUDA__)
    // Blocks until all work queued on both streams has completed.
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
#endif

    // Sets the internal step counter to `step`.
    // The former increment-then-compare sequence always ended with
    // _step == step, so a plain assignment is equivalent.
    inline void IncrementStep(size_t step) { _step = step; }

    // Replaces the hyper-parameters used by subsequent steps.
    inline void update_state(float lr, float epsilon, float weight_decay)
    {
        _alpha = lr;
        _eps = epsilon;
        _weight_decay = weight_decay;
    }

private:
    float _alpha;         // learning rate
    float _eps;           // added to sqrt(accumulator) before dividing
    float _weight_decay;  // applied only when > 0

    // Not referenced anywhere in this header; kept for layout/source
    // compatibility with code defined elsewhere.
    float _betta1_t;
    float _betta2_t;
    size_t _step;

#if defined(__ENABLE_CUDA__)
    bool _buf_index;            // which staging buffer / stream is used next
    float* _doubled_buffer[2];  // pinned host staging buffers, TILE floats each
    cudaStream_t _streams[2];   // one stream per staging buffer
#endif
};

#if defined(__AVX512__) || defined(__AVX256__)
// Vectorised Adagrad update over the SIMD-aligned prefix of the parameters.
//
// rounded_size  [out] receives ROUND_DOWN(_param_size, SIMD_WIDTH * span), i.e.
//               the number of leading elements this call processed; elements
//               past it are left untouched here (presumably handled by the
//               caller -- confirm at the call sites).
// _params       host parameters, updated in place.
// grads         host gradients (read only here).
// _exp_avg_sq   host accumulator of squared gradients, updated in place.
// _param_size   total number of elements.
// dev_params    optional device copy of the parameters; used only in CUDA
//               builds, where each finished tile is pushed through a pinned
//               staging buffer on one of two alternating streams.
// half_precision forwarded to simd_load/simd_store for grads/params and selects
//               launch_param_update_half vs. launch_param_update.
//
// Per element, assuming simd_fma(dst, a, b, c) computes a * b + c:
//   g    = grad (+ weight_decay * param when weight_decay > 0)
//   acc += g * g
//   param += -alpha * raw_grad / (sqrt(acc) + eps)
// where raw_grad is the value loaded into momentum_4.
//
// NOTE(review): return codes of the CUDA calls are not checked in this function.
template <int span>
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
                                 float* _params,
                                 float* grads,
                                 float* _exp_avg_sq,
                                 size_t _param_size,
                                 ds_half_precision_t* dev_params,
                                 bool half_precision)
{
    size_t new_rounded_size = 0;
    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    // Negated so the final fused multiply-add subtracts the update.
    float step_size = -1 * _alpha;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    // Only initialised (and only read below) when weight decay is enabled.
    AVX_Data weight_decay4;
    if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);
    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        // The last tile may be shorter than TILE.
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
        // From the third tile on, the staging buffer about to be reused was
        // handed to this stream two tiles ago; wait for that work to finish
        // before overwriting it.
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + i, half_precision);

            // NOTE(review): loaded from the same address as grad_4 but with
            // half_precision forced to false -- confirm this is intended when
            // half_precision is true.
            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, grads + i, false);

            AVX_Data variance_4[span];
            simd_load<span>(variance_4, _exp_avg_sq + i, false);

            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + i, half_precision);

            if (_weight_decay > 0) { simd_fma<span>(grad_4, param_4, weight_decay4, grad_4); }

            // Accumulate the squared gradient, then reuse grad_4 as scratch for
            // the denominator and the scaled update.
            simd_fma<span>(variance_4, grad_4, grad_4, variance_4);
            simd_sqrt<span>(grad_4, variance_4);
            simd_add<span>(grad_4, grad_4, eps_4);
            simd_div<span>(grad_4, momentum_4, grad_4);
            simd_fma<span>(param_4, grad_4, step_size_4, param_4);

            simd_store<span>(_params + i, param_4, half_precision);
#if defined(__ENABLE_CUDA__)
            // Mirror the updated parameters into the current staging buffer,
            // indexed relative to the start of this tile.
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
#endif
            simd_store<span>(_exp_avg_sq + i, variance_4, false);
        }
#if defined(__ENABLE_CUDA__)
        if (dev_params) {
            // Hand the finished tile to the device on the current stream, then
            // switch to the other buffer/stream pair for the next tile.
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
}
#endif