// multi_tensor_sgd_kernel.cu
// Originally committed by Simon Layton.
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"

#include <assert.h>
#include <cuda_runtime.h>

// Block size handed to multi_tensor_apply by multi_tensor_sgd_cuda
// (presumably threads per block — confirm against multi_tensor_apply.cuh).
#define BLOCK_SIZE 512
// Instruction-level-parallelism factor: number of elements each thread loads
// and updates per iteration of SGDFunctor's chunk loop.
#define ILP 4

/**
 * Perform fused SGD on multiple buffers.
 *
 * Invoked through multi_tensor_apply: each block handles one chunk of
 * `chunk_size` elements of one tensor, selected by
 * tl.block_to_tensor[blockIdx.x] and tl.block_to_chunk[blockIdx.x].
 * Inside the chunk each thread processes ILP elements per iteration,
 * spaced blockDim.x apart. All arithmetic is done in fp32.
 *
 * N: number of tensor lists (3, or 4 when an fp16 weight copy is produced)
 * tl[0] : gradients        (T_grad, read only)
 * tl[1] : weights          (T_weight, updated in place)
 * tl[2] : momentum buffers (T_weight, updated in place when momentum != 0)
 * tl[3] : fp16 weights     (at::Half, written only when N == 4)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar); 0 disables use of the momentum buffer
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : first optimizer step; the momentum buffer is initialised to
 *             the (possibly weight-decayed) gradient instead of being blended
 *             with its previous contents
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 *
 * If *noop_gmem is non-zero the functor returns without touching any buffer.
 **/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorList<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum)
  {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Offset every buffer to the start of this block's chunk.
    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    at::Half *model_weights_out = nullptr;
    if (N == 4) {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    // Elements remaining in this tensor from the chunk start; every access
    // below is bounded by both this and chunk_size.
    n -= chunk_idx*chunk_size;

    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      // Load phase: lanes past the end keep zeros and never touch memory.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i]);
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // Compute/store phase. The stores depend on the loads above but not on
      // each other, so unrolling mainly buys compute ILP here.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size) {
          // apply weight decay before momentum if necessary
          if (wd != 0.f && !wd_after_momentum) {
            incoming_grads[ii] += wd * incoming_weights[ii];
          }
          if (momentum != 0.f) {
            if (!first_run) {
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            } else {
              // First step: seed the momentum buffer with the current
              // gradient instead of relying on its previous contents.
              incoming_moms[ii] = incoming_grads[ii];
            }

            if (nesterov) {
              incoming_grads[ii] += momentum * incoming_moms[ii];
            } else {
              incoming_grads[ii] = incoming_moms[ii];
            }
          }

          // Apply WD after momentum if desired
          if (wd != 0.f && wd_after_momentum) {
            incoming_grads[ii] += wd * incoming_weights[ii];
          }

          // Update the already-loaded fp32 weight and store it once, rather
          // than re-reading it from global memory.
          incoming_weights[ii] += (-lr * incoming_grads[ii]);
          weight_in[i] = static_cast<T_weight>(incoming_weights[ii]);

          // if necessary, write out an fp16 copy of the weights
          if (N == 4) {
            model_weights_out[i] = static_cast<at::Half>(incoming_weights[ii]);
          }

          // also write out the new momentum
          if (momentum != 0.f) {
            mom_in[i] = static_cast<T_weight>(incoming_moms[ii]);
          }
        }
      }
    }
  }
};

/**
 * Host entry point: run the fused SGD update over every tensor in
 * `tensor_lists` by dispatching SGDFunctor through multi_tensor_apply.
 *
 * tensor_lists: 3 lists (gradients, weights, momentum buffers) or 4 lists
 *               (plus an fp16 copy of the weights to be written).
 * The functor instantiation is selected from the dtype of the first gradient
 * and the first weight (lists are assumed dtype-homogeneous — only element 0
 * is inspected). Unsupported combinations, or missing/empty gradient or
 * weight lists, raise via AT_ERROR. Launch errors are checked afterwards with
 * AT_CUDA_CHECK(cudaGetLastError()).
 */
void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum)
{
  auto num_tensors = tensor_lists.size();

  // Guard the [0][0] / [1][0] accesses below against missing or empty lists.
  if (num_tensors < 2 || tensor_lists[0].empty() || tensor_lists[1].empty()) {
    AT_ERROR("multi_tensor_sgd expects non-empty gradient and weight lists. Given num_lists: ",
             num_tensors);
  }

  // Gradient dtype comes from list 0, weight dtype from list 1.
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  // We have 4 potentials to handle here, in terms of
  // grad_type, param_type, momentum_type, requires_fp16_copy
  // 1. fp16, fp16, fp16, No
  // 2. fp16, fp32, fp32, No
  // 3. fp16, fp32, fp32, Yes
  // 4. fp32, fp32, fp32, No
  // It's easier to hardcode these possibilities than to use
  // switches etc. to handle the cross-product of cases where
  // we don't want the majority of them.

  // Case 1. fp16, fp16, fp16, No
  if (grad_type == at::ScalarType::Half &&
      weight_type == at::ScalarType::Half &&
      num_tensors == 3) {
    multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<3, at::Half, at::Half>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum);
  }
  // Case 2. fp16, fp32, fp32, No
  else if (grad_type == at::ScalarType::Half &&
           weight_type == at::ScalarType::Float &&
           num_tensors == 3) {
    multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<3, at::Half, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum);
  }
  // Case 3. fp16, fp32, fp32, Yes
  else if (grad_type == at::ScalarType::Half &&
           weight_type == at::ScalarType::Float &&
           num_tensors == 4) {
    multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<4, at::Half, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum);
  }
  // Case 4. fp32, fp32, fp32, No
  else if (grad_type == at::ScalarType::Float &&
           weight_type == at::ScalarType::Float &&
           num_tensors == 3) {
    multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<3, float, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum);
  }
  else {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}