#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include "batch_norm.h"

#include <cuda.h>
#include <memory>

#include "compat.h"

#define cudaCheckErrors(msg) \
    do { \
        cudaError_t __err = cudaGetLastError(); \
        if (__err != cudaSuccess) { \
            fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
                msg, cudaGetErrorString(__err), \
                __FILE__, __LINE__); \
            fprintf(stderr, "*** FAILED - ABORTING\n"); \
            exit(1); \
        } \
    } while (0)
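
// Example use of cudaCheckErrors (a sketch; `my_kernel` and its launch
// arguments are hypothetical):
//   my_kernel<<<grid, block, 0, stream>>>(args);
//   cudaCheckErrors("my_kernel launch failed");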

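// Round x up to the next multiple of `multiple`; used below to keep each
// workspace sub-buffer offset 512-byte aligned.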
static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

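// RAII holder for a scratch buffer taken from PyTorch's CUDA caching
// allocator; the memory is returned to the allocator when the Workspace
// goes out of scope.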
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
    dataPtr = allocator.allocate(size);
    data = dataPtr.get();
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() = default;

  size_t size;
  void* data;
  c10::DataPtr dataPtr;
};

// Forward training pass for NHWC batch norm; returns the output tensor {y}
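// A minimal call sketch (an illustrative assumption, not prescribed usage:
// fp16 NHWC tensors, bn_group == 1, no peer buffers; `grid_x` is a
// hypothetical value supplied by the caller):
//   int occ = nhwc_bn_fwd_occupancy();
//   at::Tensor y = nhwc_bn_fwd_train(x, scale, bias, running_mean, running_inv_var,
//                                    minibatch_mean, minibatch_inv_var, ret_cta,
//                                    /*momentum=*/0.9f, /*epsilon=*/1e-5f,
//                                    /*fuse_relu=*/false, /*my_data=*/nullptr,
//                                    nullptr, nullptr, nullptr, /*bn_group=*/1,
//                                    magic_tensor, occ, grid_x, /*coop=*/true);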
at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void * my_data,
                       void * pair_data,
                       void * pair_data2,
                       void * pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // Generate a new magic number (kept to 8 bits) and use it for synchronization
  int* magic = magic_tensor.DATA_PTR<int>();
  *magic = (*magic + 1) & 0xff;

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper; owned by a unique_ptr so it is freed when the function returns
  auto bn = std::make_unique<NhwcBatchNorm>();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
                             nullptr,
                             y.DATA_PTR<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.contiguous().DATA_PTR<float>()});

  // Deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // The first three workspace slots are backed by existing tensors (minibatch
  // mean, minibatch inv-var, retired-CTA buffer); allocate a single buffer and
  // hand out 512-byte-aligned offsets for the remaining slots.
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (size_t index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
  workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (size_t index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Launch the forward training kernel (ReLU is fused when fuse_relu is set)
  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);

  return y;
}

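// Forward inference pass for NHWC batch norm: normalizes x with the running
// statistics (no minibatch statistics are produced); returns the output tensor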
at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu) {

  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // Allocate output tensor
  at::Tensor y = at::empty({N, H, W, C}, x.options());

  // Create wrapper; owned by a unique_ptr so it is freed when the function returns
  auto bn = std::make_unique<NhwcBatchNorm>();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
                             nullptr,
                             y.DATA_PTR<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.contiguous().DATA_PTR<float>()});

  // Deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // Slot 2 is backed by the retired-CTA tensor and slots 0/1 are unused in
  // inference; allocate a single buffer and hand out 512-byte-aligned offsets
  // for the remaining slots.
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (size_t index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
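  // Slots 0 and 1 (minibatch mean / inv-var) are not used for inference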
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (size_t index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Launch the inference kernel (ReLU is fused when fuse_relu is set)
  bn->fwdInference(stream, fuse_relu);

  return y;
}

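// Backward pass for NHWC batch norm; returns {x_grad, scale_grad, bias_grad}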
std::vector<at::Tensor> nhwc_bn_bwd(
                       const at::Tensor& x,
                       const at::Tensor& dy,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void * my_data,
                       void * pair_data, 
                       void * pair_data2, 
                       void * pair_data3, 
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {
  // shape
  const int N = x.size(0);
  const int H = x.size(1);
  const int W = x.size(2);
  const int C = x.size(3);

  // Generate a new magic number (kept to 8 bits) and use it for synchronization
  int* magic = magic_tensor.DATA_PTR<int>();
  *magic = (*magic + 1) & 0xff;

  // outputs
  at::Tensor x_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(bias);

  // Create wrapper; owned by a unique_ptr so it is freed when the function returns
  auto bn = std::make_unique<NhwcBatchNorm>();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
                             x_grad.DATA_PTR<at::Half>(),
                             nullptr,
                             dy.contiguous().DATA_PTR<at::Half>());

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()},
                        {scale_grad.DATA_PTR<float>(),
                         bias_grad.DATA_PTR<float>()});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.contiguous().DATA_PTR<float>()});

  // Deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // The first three workspace slots are backed by existing tensors (minibatch
  // mean, minibatch inv-var, retired-CTA buffer); allocate a single buffer and
  // hand out 512-byte-aligned offsets for the remaining slots.
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (size_t index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
  workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (size_t index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);

  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}

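// Occupancy queries exposed to the caller: CTAs per SM that the generated
// kernels can keep resident on the current device (the code supports at most 2)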
int nhwc_bn_fwd_occupancy() {
    int device_id = -1;
    cudaGetDevice(&device_id);

    // max occupancy supported by the code is 2
    return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
}

int nhwc_bn_bwd_occupancy() {
    int device_id = -1;
    cudaGetDevice(&device_id);

    // max occupancy supported by the code is 2
    return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
}