// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cpu_adagrad.h"
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

// Registry of live optimizers keyed by the integer id handed out to Python.
// Values are type-erased (shared_ptr<void>) so one map could hold several
// optimizer types; lookups cast back with std::static_pointer_cast.
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

void Adagrad_Optimizer::Step_1(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
aiss's avatar
aiss committed
28
                               ds_half_precision_t* dev_params,
aiss's avatar
aiss committed
29
30
31
32
33
34
35
36
37
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<1>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size) {
        float step_size = -1 * _alpha;
aiss's avatar
aiss committed
38
39
        ds_half_precision_t* grads_cast_h;
        ds_half_precision_t* params_cast_h;
aiss's avatar
aiss committed
40
        if (half_precision) {
aiss's avatar
aiss committed
41
42
            grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
            params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
aiss's avatar
aiss committed
43
44
45
46
47
        }
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
aiss's avatar
aiss committed
48
#if defined(__ENABLE_CUDA__)
aiss's avatar
aiss committed
49
            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
aiss's avatar
aiss committed
50
#endif
aiss's avatar
aiss committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
                float param = half_precision ? (float)params_cast_h[k] : _params[k];
                float momentum = grads[k];
                float variance = _exp_avg_sq[k];
                if (_weight_decay > 0) { grad = param * _weight_decay + grad; }

                variance += grad * grad;

                grad = sqrt(variance);
                grad += _eps;
                grad = momentum / grad;
                param = grad * step_size + param;
aiss's avatar
aiss committed
65
#if defined(__ENABLE_CUDA__)
aiss's avatar
aiss committed
66
                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
aiss's avatar
aiss committed
67
#endif
aiss's avatar
aiss committed
68
                if (half_precision)
aiss's avatar
aiss committed
69
                    params_cast_h[k] = (ds_half_precision_t)param;
aiss's avatar
aiss committed
70
71
72
73
74
75
                else
                    _params[k] = param;
                // STORE UPDATE TERM TO GRAD'S MEMORY
                grads[k] = grad * step_size;
                _exp_avg_sq[k] = variance;
            }
aiss's avatar
aiss committed
76
#if defined(__ENABLE_CUDA__)
aiss's avatar
aiss committed
77
78
79
80
81
            if (dev_params) {
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
                _buf_index = !_buf_index;
            }
aiss's avatar
aiss committed
82
#endif
aiss's avatar
aiss committed
83
84
85
86
87
88
89
90
        }
    }
}

void Adagrad_Optimizer::Step_4(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
aiss's avatar
aiss committed
91
                               ds_half_precision_t* dev_params,
aiss's avatar
aiss committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<4>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_1((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}

// Build an Adagrad_Optimizer with the given hyperparameters and register it
// under `optimizer_id`, replacing any optimizer previously stored there.
// Optionally logs the detected SIMD capability and config. Always returns 0.
int create_adagrad_optimizer(int optimizer_id,
                             float alpha = 1e-2,
                             float eps = 1e-8,
                             float weight_decay = 0,
                             bool should_log = false)
{
    s_optimizers[optimizer_id] = std::make_shared<Adagrad_Optimizer>(alpha, eps, weight_decay);

    if (should_log) {
        // Report which arithmetic path this build was compiled with.
#if defined(__AVX512__)
        const char* avx_type = "AVX512";
#elif defined(__AVX256__)
        const char* avx_type = "AVX2";
#else
        const char* avx_type = "scalar";
#endif
        printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n",
               optimizer_id,
               avx_type);
        printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay);
    }

    return 0;
}

void Adagrad_Optimizer::Step_8(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
aiss's avatar
aiss committed
143
                               ds_half_precision_t* dev_params,
aiss's avatar
aiss committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<8>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_4((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}

int ds_adagrad_step(int optimizer_id,
                    size_t step,
                    float lr,
                    float epsilon,
                    float weight_decay,
                    torch::Tensor& params,
                    torch::Tensor& grads,
                    torch::Tensor& exp_avg_sq)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adagrad_Optimizer> opt =
        std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step);
    opt->update_state(lr, epsilon, weight_decay);
aiss's avatar
aiss committed
181
    opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
aiss's avatar
aiss committed
182

aiss's avatar
aiss committed
183
#if defined(__ENABLE_CUDA__)
aiss's avatar
aiss committed
184
    opt->SynchronizeStreams();
aiss's avatar
aiss committed
185
#endif
aiss's avatar
aiss committed
186
187
188
189
190
191
192
193
194
195
196
197
198
    return 0;
}

// Python-facing Adagrad step that also mirrors the updated parameters into
// `gpu_params` (half-precision device tensor) via the optimizer's async copy
// path. Only available in CUDA builds; otherwise asserts.
//
// NOTE(review): in non-CUDA builds the #else branch is assert(false), which
// compiles to a no-op under NDEBUG — the call would then silently return 0.
int ds_adagrad_step_plus_copy(int optimizer_id,
                              size_t step,
                              float lr,
                              float epsilon,
                              float weight_decay,
                              torch::Tensor& params,
                              torch::Tensor& grads,
                              torch::Tensor& exp_avg_sq,
                              torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
    // Contiguous views; see ds_adagrad_step for the non-contiguous caveat.
    auto params_c = params.contiguous();
    auto gpu_params_c = gpu_params.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();
    auto grads_c = grads.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    // Device-side destination for the async parameter copy inside Step_*.
    ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adagrad_Optimizer> opt =
        std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step);
    opt->update_state(lr, epsilon, weight_decay);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_sq_ptr,
                params_c.numel(),
                gpu_params_ptr,
                // half_precision flag is derived from the host params dtype.
                (params.options().dtype() == at::kHalf));

    // Wait for the per-tile async copies launched by Step_* to complete.
    opt->SynchronizeStreams();
#else
    assert(false);
#endif
    return 0;
}

// Drop the optimizer registered under `optimizer_id`; the instance is
// destroyed once its last shared_ptr reference goes away. An unknown id is
// silently ignored. Always returns 0.
int destroy_adagrad_optimizer(int optimizer_id)
{
    auto it = s_optimizers.find(optimizer_id);
    if (it != s_optimizers.end()) { s_optimizers.erase(it); }

    return 0;
}

// Python bindings for the CPU Adagrad extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Optimizer lifecycle.
    m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
    m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
    // Step entry points.
    m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
    m.def("adagrad_update_copy",
          &ds_adagrad_step_plus_copy,
          "DeepSpeed CPU Adagrad update and param copy (C++)");
}