"sgl-kernel/vscode:/vscode.git/clone" did not exist on "bde24ab31f89f56516616ba40df074fd1117679e"
multi_tensor_adagrad.cu 2.98 KB
Newer Older
Andrew Tulloch's avatar
Andrew Tulloch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 1024
#define ILP 4
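
// BLOCK_SIZE: threads per CUDA block. ILP: elements processed per thread per
// loop iteration (instruction-level parallelism), so one block covers up to
// BLOCK_SIZE * ILP elements of a chunk each iteration.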

typedef enum {
  ADAGRAD_MODE_0 = 0, // L2 regularization mode.
  ADAGRAD_MODE_1 = 1, // AdamW-style weight decay.
} adagradMode_t;

using MATH_T = float;
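
// For reference, the per-element updates implemented by AdagradFunctor below
// (scalar sketch; g = gradient, p = parameter, h = accumulated squared grads):
//   ADAGRAD_MODE_0 (L2):    g += weight_decay * p
//                           h += g * g
//                           p -= lr * g / (sqrtf(h) + epsilon)
//   ADAGRAD_MODE_1 (AdamW): h += g * g
//                           p -= lr * (g / (sqrtf(h) + epsilon) + weight_decay * p)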

template <typename T> struct AdagradFunctor {
  __device__ __forceinline__ void
  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
             const float epsilon, const float lr, adagradMode_t mode,
             const float weight_decay) {
    // multi_tensor_apply assigns each block one chunk of one tensor.
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T *g = (T *)tl.addresses[0][tensor_loc]; // gradient
    g += chunk_idx * chunk_size;

    T *p = (T *)tl.addresses[1][tensor_loc]; // parameter
    p += chunk_idx * chunk_size;

    T *h = (T *)tl.addresses[2][tensor_loc]; // accumulated squared gradients
    h += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size; // elements remaining from the start of this chunk

    // See note in multi_tensor_scale_kernel.cu on this ILP loop structure.
    for (int i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP) {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_h[ILP];
      // Load ILP elements per thread into registers, zero-filling
      // out-of-range slots.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_h[ii] = h[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_h[ii] = MATH_T(0);
        }
      }
      // Apply the Adagrad update in registers.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        if (mode == ADAGRAD_MODE_0) { // L2: fold decay into the gradient.
          r_g[ii] = r_g[ii] + weight_decay * r_p[ii];
          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon));
        } else { // AdamW-style: decoupled weight decay.
          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon) +
                                    weight_decay * r_p[ii]);
        }
      }
      // Write updated parameters and optimizer state back to global memory.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          p[i] = r_p[ii];
          h[i] = r_h[ii];
        }
      }
    }
  }
};

void multi_tensor_adagrad_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
    const float epsilon, const int mode, const float weight_decay) {
  using namespace at;

  // Assume a single scalar type across p, g, and h for now.
  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
      tensor_lists[0][0].scalar_type(), 0, "adagrad",
      multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                            AdagradFunctor<scalar_t_0>(), epsilon, lr,
                            (adagradMode_t)mode, weight_decay);)

  AT_CUDA_CHECK(cudaGetLastError());
}
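
// Example invocation (a sketch, not part of the original file): calling the
// launcher from C++ with ATen tensors. In practice this function would be
// exposed to Python through a binding layer (e.g. pybind11); the tensor
// names and chunk size below are illustrative assumptions.
//
//   at::Tensor noop_flag =
//       at::zeros({1}, at::TensorOptions().device(at::kCUDA).dtype(at::kInt));
//   // tensor_lists layout matches the functor: [0] = grads, [1] = params,
//   // [2] = state sums; each a vector of CUDA tensors of the same dtype.
//   std::vector<std::vector<at::Tensor>> tensor_lists = {grads, params,
//                                                        state_sums};
//   multi_tensor_adagrad_cuda(/*chunk_size=*/2048 * 32, noop_flag,
//                             tensor_lists, /*lr=*/1e-2f,
//                             /*epsilon=*/1e-10f,
//                             /*mode=*/ADAGRAD_MODE_0,
//                             /*weight_decay=*/0.0f);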