// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_runtime.h>

#include "compat.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4
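// Each thread handles ILP elements per iteration of the main loops below, so
// a block walks its chunk in strides of blockDim.x * ILP elements.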

/**
 * Perform fused SGD on multiple buffers
 * N: number of tensors
 * tl[0] : gradients
 * tl[1] : weights
 * tl[2] : momentum buffers
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : necessary for proper momentum handling & init
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 **/
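
// A sketch of the per-element update the functor below performs, written out
// for reference (the placement of the weight-decay term depends on
// wd_after_momentum, and first_run initializes the momentum buffer to g):
//   g = grad * scale (+ wd * w, when weight decay is applied before momentum)
//   m = momentum * m + (1 - dampening) * g
//   g = g + momentum * m   if nesterov, else g = m
//   w = w - lr * g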
template <typename T_grad, typename T_weight>
struct SGDFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
      float wd, float momentum, float dampening, float lr, bool nesterov,
      bool first_run, bool wd_after_momentum, float scale) {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

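    // Map this block to its tensor and to its chunk within that tensor, then
    // advance the gradient/weight/momentum pointers to the start of the chunk.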
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx * chunk_size;

    T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx * chunk_size;

    T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for (int i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP) {
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          // apply weight decay before momentum if necessary
          if (wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if (momentum != 0.f) {
            if (!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum +
                                  (1.f - dampening) * incoming_grads[ii];
            else  // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];

            if (nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Apply WD after momentum if desired
          if (wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // adjust the weight and write out
          weight_in[i] += (-lr * incoming_grads[ii]);

          // also write out the new momentum
          if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
        }
      }
    }
  }
};

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale) {
  auto num_tensors = tensor_lists.size();
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
              "expected noop flag to be on the same device as tensors");

  // We have 3 possibilities to handle here, in terms of
  // grad_type, param_type, momentum_type
  // 1. fp16, fp16, fp16
  // 2. fp32, fp32, fp32
  // 3. fp16, fp32, fp32
  // It's easier to hardcode these possibilities than to use
  // switches etc. to handle the cross-product of cases where
  // we don't want the majority of them.

  // Case 1. fp16, fp16, fp16
  if (grad_type == at::ScalarType::Half &&
      weight_type == at::ScalarType::Half && num_tensors == 3) {
    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<at::Half, at::Half>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  }
  // Case 2. fp32, fp32, fp32
  else if (grad_type == at::ScalarType::Float &&
           weight_type == at::ScalarType::Float && num_tensors == 3) {
    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<float, float>(), wd, momentum, dampening,
                          lr, nesterov, first_run, wd_after_momentum, scale);
  }
  // Case 3. fp16, fp32, fp32
  else if (grad_type == at::ScalarType::Half &&
           weight_type == at::ScalarType::Float && num_tensors == 3) {
    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<at::Half, float>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  } else {
    AT_ERROR(
        "multi_tensor_sgd only supports some combinations of gradient & weight "
        "types. Given: ",
        "gradient: ", grad_type, ", weight: ", weight_type,
        ", num_lists: ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
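
// A minimal host-side usage sketch (illustrative only; the chunk size and the
// hyperparameter values below are hypothetical, not defined in this file):
//
//   std::vector<std::vector<at::Tensor>> lists{grads, weights, momenta};
//   auto noop_flag =
//       at::zeros({1}, lists[0][0].options().dtype(at::kInt));
//   multi_tensor_sgd_cuda(/*chunk_size=*/2048 * 32, noop_flag, lists,
//                         /*wd=*/0.0f, /*momentum=*/0.9f, /*dampening=*/0.0f,
//                         /*lr=*/0.01f, /*nesterov=*/false,
//                         /*first_run=*/true, /*wd_after_momentum=*/false,
//                         /*scale=*/1.0f);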