"vscode:/vscode.git/clone" did not exist on "cc3cbe9f6f291af172252f097952bfe247200195"
cpu_adam.cpp 18.1 KB
Newer Older
LuGY's avatar
LuGY committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
23

LuGY's avatar
LuGY committed
24
25
#include <math.h>
#include <omp.h>
26
#include <string.h>
27
28
29

#include <iostream>
#include <memory>
LuGY's avatar
LuGY committed
30
31
32
33
34
#include <type_traits>
#include <unordered_map>

// C++ interface

35
36
37
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
38
39
40
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
LuGY's avatar
LuGY committed
41

42
43
44
45
  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;
LuGY's avatar
LuGY committed
46

47
48
49
50
  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
LuGY's avatar
LuGY committed
51
52

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
53
54
55
56
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);
LuGY's avatar
LuGY committed
57

58
59
60
61
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);
LuGY's avatar
LuGY committed
62

63
64
  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);
LuGY's avatar
LuGY committed
65

66
67
  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);
LuGY's avatar
LuGY committed
68

69
70
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);
LuGY's avatar
LuGY committed
71

72
73
74
75
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
LuGY's avatar
LuGY committed
76

77
78
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
79
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
80
    size_t offset = copy_size + t;
LuGY's avatar
LuGY committed
81
82

#pragma omp parallel for
83
84
    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
      AVX_Data grad_4;
85
      this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
86
87
88
89
90
91
      if (loss_scale > 0) {
        AVX_Data loss_scale_vec;
        loss_scale_vec.data = SIMD_SET(loss_scale);
        grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
      }
      AVX_Data momentum_4;
92
93
      this->simd_load(momentum_half_precision, _exp_avg + i,
                      momentum_cast_h + i, momentum_4);
94
95

      AVX_Data variance_4;
96
97
      this->simd_load(variance_half_precision, _exp_avg_sq + i,
                      variance_cast_h + i, variance_4);
98
99

      AVX_Data param_4;
100
101
      this->simd_load(param_half_precision, _params + i, params_cast_h + i,
                      param_4);
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

      if (_weight_decay > 0 && !_adamw_mode) {
        grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
      }
      momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
      momentum_4.data =
          SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
      variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
      grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
      variance_4.data =
          SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
      grad_4.data = SIMD_SQRT(variance_4.data);
      grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
      grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);

      if (_weight_decay > 0 && _adamw_mode) {
        param_4.data =
            SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
      }
      param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

123
124
125
126
127
128
      this->simd_store(param_half_precision, _params + i, params_cast_h + i,
                       param_4);
      this->simd_store(momentum_half_precision, _exp_avg + i,
                       momentum_cast_h + i, momentum_4);
      this->simd_store(variance_half_precision, _exp_avg_sq + i,
                       variance_cast_h + i, variance_4);
LuGY's avatar
LuGY committed
129
    }
130
  }
LuGY's avatar
LuGY committed
131
#endif
132
133
134
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
135
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
136
      size_t offset = copy_size + t;
LuGY's avatar
LuGY committed
137
138

#pragma omp parallel for
139
140
141
142
143
144
145
      for (size_t k = t; k < offset; k++) {
        float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param =
            param_half_precision ? (float)params_cast_h[k] : _params[k];
146
147
148
149
        float momentum =
            momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
        float variance = variance_half_precision ? (float)variance_cast_h[k]
                                                 : _exp_avg_sq[k];
150
151
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
LuGY's avatar
LuGY committed
152
        }
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;

        variance = variance * _betta2;
        grad = grad * grad;
        variance = grad * betta2_minus1 + variance;

        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;

        if (param_half_precision)
          params_cast_h[k] = (__half)param;
        else
          _params[k] = param;
172
173
174
175
176
177
178
179
        if (momentum_half_precision)
          momentum_cast_h[k] = (__half)(momentum);
        else
          _exp_avg[k] = momentum;
        if (variance_half_precision)
          variance_cast_h[k] = (__half)(variance);
        else
          _exp_avg_sq[k] = variance;
180
      }
LuGY's avatar
LuGY committed
181
    }
182
  }
LuGY's avatar
LuGY committed
183
184
}

185
186
187
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
188
189
190
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
191

192
193
194
195
  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
LuGY's avatar
LuGY committed
196
197

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
228
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
229
    size_t offset = copy_size + t;
LuGY's avatar
LuGY committed
230
231

#pragma omp parallel for
232
233
234
235
236
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
      AVX_Data grad_4[4];
      AVX_Data momentum_4[4];
      AVX_Data variance_4[4];
      AVX_Data param_4[4];
LuGY's avatar
LuGY committed
237
#pragma unroll 4
238
      for (int j = 0; j < 4; j++) {
239
240
        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
241
242
243
244
245
246

        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
247
248
249
250
251
252
253
        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_load(variance_half_precision,
                        _exp_avg_sq + i + SIMD_WIDTH * j,
                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
254
255
256
257

        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
LuGY's avatar
LuGY committed
258
        }
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);

        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
276
277
278
279
280
281
282
        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_store(variance_half_precision,
                         _exp_avg_sq + i + SIMD_WIDTH * j,
                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
283
      }
LuGY's avatar
LuGY committed
284
    }
285
  }
LuGY's avatar
LuGY committed
286
#endif
287
288
289
290
291
  if (_param_size > rounded_size)
    Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
292
293
294
295
           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
                                    : _exp_avg + rounded_size),
           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
                                    : _exp_avg_sq + rounded_size),
296
           (_param_size - rounded_size), param_half_precision,
297
298
           grad_half_precision, momentum_half_precision,
           variance_half_precision, loss_scale);
LuGY's avatar
LuGY committed
299
300
}

301
302
303
void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
304
305
306
307
308
309
310
311
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);

LuGY's avatar
LuGY committed
312
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
343
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
344
    size_t offset = copy_size + t;
LuGY's avatar
LuGY committed
345
346

#pragma omp parallel for
347
348
349
350
351
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
      AVX_Data grad_4[8];
      AVX_Data momentum_4[8];
      AVX_Data variance_4[8];
      AVX_Data param_4[8];
LuGY's avatar
LuGY committed
352
#pragma unroll 8
353
      for (int j = 0; j < 8; j++) {
354
355
        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
356
357
358
359
360

        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
LuGY's avatar
LuGY committed
361
        }
362
363
364
365
366
367
368
        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_load(variance_half_precision,
                        _exp_avg_sq + i + SIMD_WIDTH * j,
                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390

        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);

391
392
393
394
395
396
397
        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_store(variance_half_precision,
                         _exp_avg_sq + i + SIMD_WIDTH * j,
                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
398
      }
LuGY's avatar
LuGY committed
399
    }
400
  }
LuGY's avatar
LuGY committed
401
#endif
402
403
404
405
406
  if (_param_size > rounded_size)
    Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
407
408
409
410
           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
                                    : _exp_avg + rounded_size),
           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
                                    : _exp_avg_sq + rounded_size),
411
           (_param_size - rounded_size), param_half_precision,
412
413
           grad_half_precision, momentum_half_precision,
           variance_half_precision, loss_scale);
LuGY's avatar
LuGY committed
414
415
}

416
417
418
419
420
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
                          float epsilon, float weight_decay,
                          bool bias_correction, torch::Tensor &params,
                          torch::Tensor &grads, torch::Tensor &exp_avg,
                          torch::Tensor &exp_avg_sq, float loss_scale) {
421
422
423
424
425
426
427
428
429
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  float *params_ptr = (float *)params_c.data_ptr();
  float *grads_ptr = (float *)grads_c.data_ptr();
  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
LuGY's avatar
LuGY committed
430

431
432
433
434
  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
               params_c.numel(), (params.options().dtype() == at::kHalf),
435
436
437
               (grads.options().dtype() == at::kHalf),
               (exp_avg.options().dtype() == at::kHalf),
               (exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
LuGY's avatar
LuGY committed
438
439
}

440
441
namespace py = pybind11;

442
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
443
444
445
  py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
      .def(py::init<float, float, float, float, float, bool>())
      .def("step", &Adam_Optimizer::step);
LuGY's avatar
LuGY committed
446
}