spline_conv_gpu.py 16.8 KB
Newer Older
1
import torch
2
from torch.autograd import Function, Variable
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

from ....utils.cuda import (cuda_num_threads, Stream, load_kernel, kernel_loop,
                            get_blocks)

# Forward edge-wise spline weighting, as CUDA source with ${...} placeholders
# (Dtype, M_in, M_out, k_max) that `get_weighting_forward_kernel` passes to
# `load_kernel`.
#
# One thread per entry of `output` (idx = e_idx * M_out + m_out_idx):
#   output[e, o] = sum_k amount[e, k] * sum_i weight[index[e, k], i, o] * input[e, i]
# `weight` is addressed as a contiguous (K, M_in, M_out) array, `input` as
# (num_edges, M_in) and `amount` / `index` as (num_edges, k_max).
# `CUDA_KERNEL_LOOP` comes from the `kernel_loop` prefix (defined in
# utils.cuda); presumably it iterates idx over [0, num_threads) — confirm there.
_edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_forward_kernel(
const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output,
const ${Dtype}* amount, const long* index, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    ${Dtype} result = 0.0;
    ${Dtype} w;
    ${Dtype} f;
    int k;
    ${Dtype} b;
    long c;
    long w_idx;

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      k = e_idx * ${k_max} + k_idx;
      b = amount[k];
      c = index[k];

      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        w_idx = c * ${M_out} * ${M_in} +
                m_in_idx * ${M_out} +
                m_out_idx;

        w = weight[w_idx];
        f = input[e_idx * ${M_in} + m_in_idx];

        result += b * w * f;
      }
    }

    output[idx] = result;
  }
}
'''

# Backward of the edge-wise spline weighting WITHOUT a gradient w.r.t. the
# basis products. CUDA source with ${...} placeholders (Dtype, M_in, M_out,
# k_max) supplied by `get_weighting_backward_kernel`.
#
# One thread per (edge, output channel) pair (idx = e_idx * M_out + m_out_idx).
# It accumulates into `grad_input` and `grad_weight` with atomicAdd, so both
# buffers must be zero-filled before launch.
#
# `grad_amount` is accepted but never touched. It keeps this kernel's argument
# list identical to the `_bp2adj` variant. `SplineConvGPU.backward` passes
# grad_amount in both of its branches; without this parameter every argument
# after `grad_weight` would be bound to the wrong pointer.
_edgewise_spline_weighting_backward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_backward_kernel(
const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
${Dtype}* grad_amount, const ${Dtype}* input, const ${Dtype}* weight,
const ${Dtype}* amount, const long* index, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    // The output gradient of this (edge, output channel) pair is the same
    // for every iteration of both loops below.
    const ${Dtype} g = grad_output[e_idx * ${M_out} + m_out_idx];

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      const int k = e_idx * ${k_max} + k_idx;
      const ${Dtype} b = amount[k];
      const long c = index[k];

      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        const long w_idx = c * ${M_out} * ${M_in} +
                           m_in_idx * ${M_out} +
                           m_out_idx;

        // Input gradient: every output channel of this edge adds into the
        // same input element, hence the atomicAdd.
        atomicAdd(&(grad_input[e_idx * ${M_in} + m_in_idx]),
                  b * weight[w_idx] * g);

        // Weight gradient: many edges share the same weight entry.
        atomicAdd(&(grad_weight[w_idx]),
                  input[e_idx * ${M_in} + m_in_idx] * b * g);
      }
    }
  }
}
'''

# Backward of the edge-wise spline weighting that ALSO produces the gradient
# w.r.t. the basis products `amount` (used when back-propagating to the
# pseudo-coordinates). CUDA source with ${...} placeholders (Dtype, M_in,
# M_out, k_max) supplied by `get_weighting_backward_kernel(bp_to_adj=True)`.
#
# One thread per (edge, output channel) pair (idx = e_idx * M_out + m_out_idx):
#   grad_input[e, i]  += sum_k amount[e, k] * weight[index[e, k], i, o] * g
#   grad_weight[c, i, o] += input[e, i] * amount[e, k] * g
#   grad_amount[e, k] += g * sum_i input[e, i] * weight[index[e, k], i, o]
# where g = grad_output[e, o]. Every update is an atomicAdd, so all three
# output buffers must be zero-filled before launch.
_edgewise_spline_weighting_backward_kernel_bp2adj = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_backward_kernel(
const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
${Dtype}* grad_amount, const ${Dtype}* input, const ${Dtype}* weight,
const ${Dtype}* amount, const long* index, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    // The output gradient of this (edge, output channel) pair is the same
    // for every iteration of both loops below.
    const ${Dtype} g = grad_output[e_idx * ${M_out} + m_out_idx];

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      const int k = e_idx * ${k_max} + k_idx;
      const ${Dtype} b = amount[k];
      const long c = index[k];

      // This output channel's contribution to d(loss) / d(amount[k]).
      ${Dtype} adj_g = 0.0;

      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        const long w_idx = c * ${M_out} * ${M_in} +
                           m_in_idx * ${M_out} +
                           m_out_idx;

        const ${Dtype} w = weight[w_idx];
        const ${Dtype} f = input[e_idx * ${M_in} + m_in_idx];

        // Input gradient: every output channel of this edge adds into the
        // same input element, hence the atomicAdd.
        atomicAdd(&(grad_input[e_idx * ${M_in} + m_in_idx]), b * w * g);

        // Weight gradient: many edges share the same weight entry.
        atomicAdd(&(grad_weight[w_idx]), f * b * g);

        // B-spline basis tensor product gradient.
        adj_g += g * f * w;
      }

      // Every output channel of this edge adds into the same slot.
      atomicAdd(&(grad_amount[k]), adj_g);
    }
  }
}
'''


def get_weighting_forward_kernel(M_in, M_out, k_max, dtype='float'):
    """Load the forward edge-wise spline weighting CUDA kernel.

    Args:
        M_in: Number of input feature channels (fills ``${M_in}``).
        M_out: Number of output feature channels (fills ``${M_out}``).
        k_max: Number of basis products per edge, i.e. the second dimension
            of the ``amount``/``index`` arrays (fills ``${k_max}``).
        dtype: CUDA scalar type name substituted for ``${Dtype}``.

    Returns:
        Whatever ``load_kernel`` returns for
        ``edgewise_spline_weighting_forward_kernel``.
    """
    # The one-element CUDA tensor exists only to select the current CUDA
    # device while the kernel is loaded.
    cuda_tensor = torch.FloatTensor([1]).cuda()
    with torch.cuda.device_of(cuda_tensor):
        f_fw = load_kernel(
            'edgewise_spline_weighting_forward_kernel',
            _edgewise_spline_weighting_forward_kernel,
            Dtype=dtype,
            M_in=M_in,
            M_out=M_out,
            k_max=k_max)
    return f_fw


def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False,
                                  dtype='float'):
    """Load the backward edge-wise spline weighting CUDA kernel.

    Args:
        M_in: Number of input feature channels (fills ``${M_in}``).
        M_out: Number of output feature channels (fills ``${M_out}``).
        k_max: Number of basis products per edge (fills ``${k_max}``).
        K: Forwarded to ``load_kernel`` as ``K``. Neither backward template
            in this file references ``${K}``.
        bp_to_adj: If true, load the variant that also writes the gradient
            w.r.t. the basis products (``grad_amount``).
        dtype: CUDA scalar type name substituted for ``${Dtype}``.

    Returns:
        Whatever ``load_kernel`` returns for
        ``edgewise_spline_weighting_backward_kernel``. Both variants share
        that entry-point name.
    """
    if bp_to_adj:
        kernel = _edgewise_spline_weighting_backward_kernel_bp2adj
    else:
        kernel = _edgewise_spline_weighting_backward_kernel

    # The one-element CUDA tensor exists only to select the current CUDA
    # device while the kernel is loaded.
    cuda_tensor = torch.FloatTensor([1]).cuda()
    with torch.cuda.device_of(cuda_tensor):
        f_bw = load_kernel(
            'edgewise_spline_weighting_backward_kernel',
            kernel,
            Dtype=dtype,
            M_in=M_in,
            M_out=M_out,
            k_max=k_max,
            K=K)
    return f_bw


# Degree-1 (linear) B-spline basis, as CUDA source with ${...} placeholders
# (Dtype, k_max, K, dim) supplied by `get_basis_kernel`.
#
# One thread per (edge, basis product) pair: idx = e_idx * k_max + k_idx.
# For each dimension d, the d-th bit of k_idx selects the lower (`bot`,
# weight 1 - frac) or upper (`top`, weight frac) knot. `amount[idx]` receives
# the product of these weights over all dimensions. `index[idx]` receives the
# flattened weight index sum_d knot_d * K_d, where K_d is ${K} divided by
# kernel_size[0..d] (presumably K == prod(kernel_size) — confirm with caller).
# Knots wrap modulo kernel_size[d]. Open splines scale the input by
# (kernel_size - 1) instead of kernel_size.
_spline_kernel_linear = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
const ${Dtype}* input, ${Dtype}* amount, long* index,
const long* kernel_size, const long* is_open_spline, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${k_max};
    int k_idx = idx % ${k_max};

    int K = ${K};
    int k_idx_mod;
    int bot;
    int top;
    ${Dtype} value;
    ${Dtype} frac;
    ${Dtype} a = 1.0;
    long i = 0;

    for (int d_idx = 0; d_idx < ${dim}; d_idx++) {

      K /= kernel_size[d_idx];

      k_idx_mod = k_idx % 2;
      k_idx >>= 1;

      value = input[e_idx * ${dim} + d_idx];
      value *= kernel_size[d_idx] - is_open_spline[d_idx];

      frac = value - floor(value);

      a *= (1 - k_idx_mod) * (1 - frac) + k_idx_mod * frac;

      bot = int(floor(value));
      top = (bot + 1) % kernel_size[d_idx];
      bot %= kernel_size[d_idx];
      i += ((1 - k_idx_mod) * bot + k_idx_mod * top) * K;
    }

    amount[idx] = a;
    index[idx] = i;
  }
}
'''

# Degree-2 (quadratic) B-spline basis, as CUDA source with ${...} placeholders
# (Dtype, k_max, K, dim) supplied by `get_basis_kernel`.
#
# One thread per (edge, basis product) pair: idx = e_idx * k_max + k_idx.
# For each dimension d, the base-3 digit k_idx_mod of k_idx selects one of the
# three non-zero quadratic basis functions of the fractional position `frac`:
#   0: (1 - frac)^2 / 2,   1: -frac^2 + frac + 1/2,   2: frac^2 / 2
# `amount[idx]` gets the product of these weights over all dimensions.
# `index[idx]` gets the flattened weight index sum_d pos_d * K_d, where K_d is
# ${K} divided by kernel_size[0..d]. Positions wrap modulo kernel_size[d].
# Open splines scale the input by (kernel_size - 2) instead of kernel_size.
_spline_kernel_quadratic = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
const ${Dtype}* input, ${Dtype}* amount, long* index,
const long* kernel_size, const long* is_open_spline, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${k_max};
    int k_idx = idx % ${k_max};

    int K = ${K};
    int k_idx_mod;
    int pos;
    ${Dtype} value;
    ${Dtype} frac;
    ${Dtype} a = 1.0;
    long i = 0;

    for (int d_idx = 0; d_idx < ${dim}; d_idx++) {

      K /= kernel_size[d_idx];

      k_idx_mod = k_idx % 3;
      k_idx /= 3;

      value = input[e_idx * ${dim} + d_idx] *
              (kernel_size[d_idx] - (2 * is_open_spline[d_idx]));

      frac = value - floor(value);

      if (k_idx_mod == 0) a *= 0.5 * (1- frac) * (1-frac);
      else if (k_idx_mod == 1) a *= -frac * frac + frac + 0.5;
      else a *= 0.5 * frac * frac;

      pos = int(floor(value)) + k_idx_mod;
      pos %= kernel_size[d_idx];

      i += pos * K;
    }
    amount[idx] = a;
    index[idx] = i;
  }
}
'''

# Degree-3 (cubic) B-spline basis, as CUDA source with ${...} placeholders
# (Dtype, k_max, K, dim) supplied by `get_basis_kernel`.
#
# One thread per (edge, basis product) pair: idx = e_idx * k_max + k_idx.
# For each dimension d, the base-4 digit k_idx_mod of k_idx selects one of the
# four non-zero cubic basis functions of the fractional position `frac`:
#   0: (1 - t)^3 / 6,              1: (3t^3 - 6t^2 + 4) / 6,
#   2: (-3t^3 + 3t^2 + 3t + 1) / 6, 3: t^3 / 6          (t = frac)
# `amount[idx]` gets the product over all dimensions. `index[idx]` gets the
# flattened weight index sum_d pos_d * K_d, where K_d is ${K} divided by
# kernel_size[0..d]. Positions wrap modulo kernel_size[d]. Open splines scale
# the input by (kernel_size - 3).
#
# Fix: the original `CUDA_KERNEL_LOOP(idx, num_threads})` had a stray `}`
# that made the kernel fail to compile.
_spline_kernel_cubic = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
const ${Dtype}* input, ${Dtype}* amount, long* index,
const long* kernel_size, const long* is_open_spline, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${k_max};
    int k_idx = idx % ${k_max};

    int K = ${K};
    int k_idx_mod;
    int pos;
    ${Dtype} value;
    ${Dtype} frac;
    ${Dtype} a = 1.0;
    long i = 0;

    for (int d_idx = 0; d_idx < ${dim}; d_idx++) {

      K /= kernel_size[d_idx];

      k_idx_mod = k_idx % 4;
      k_idx /= 4;

      value = input[e_idx * ${dim} + d_idx] *
              (kernel_size[d_idx] - (3 * is_open_spline[d_idx]));

      frac = value - floor(value);

      if (k_idx_mod == 0) a *= (1 - frac) * (1 - frac) * (1 - frac) / 6.0;
      else if (k_idx_mod == 1)
        a *= (3 * frac * frac * frac - 6 * frac * frac + 4) / 6.0;
      else if (k_idx_mod == 2)
        a *= (-3 * frac * frac * frac + 3 * frac * frac + 3 * frac + 1) / 6.0;
      else a *= frac * frac * frac / 6.0;

      pos = int(floor(value)) + k_idx_mod;
      pos %= kernel_size[d_idx];

      i += pos * K;
    }
    amount[idx] = a;
    index[idx] = i;
  }
}
'''

# Backward of the degree-1 (linear) B-spline basis w.r.t. the input
# pseudo-coordinates. CUDA source with ${...} placeholders (Dtype, k_max, dim)
# supplied by `get_basis_backward_kernel`.
#
# One thread per (edge, dimension) pair: idx = e_idx * dim + d_idx.
# The forward kernel sets amount[e, k] = prod_d w_d(k), where w_d is
# (1 - frac_d) or frac_d depending on bit d of k. Hence
#   d amount[e, k] / d frac_{d_idx} = (-1 or +1) * prod_{d != d_idx} w_d(k)
# and the chain rule through value = input * (kernel_size - is_open_spline)
# contributes the final scale factor.
#
# The other dimensions' weights are multiplied out directly. The previous
# `amount / residual` form evaluated 0 / 0 = NaN whenever frac_{d_idx} == 0
# (an input lying exactly on a knot, e.g. 0). The `amount` argument is
# therefore unused, but it stays in the signature so the launch argument list
# does not change.
_spline_kernel_linear_backward = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
const ${Dtype}* input, const ${Dtype}* grad_amount, ${Dtype}* amount,
${Dtype}* grad_adj, const long* kernel_size, const long* is_open_spline,
int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${dim};
    const int d_idx = idx % ${dim};

    // Fractional position of this edge inside its knot interval, per dim.
    ${Dtype} frac[${dim}];
    for (int d_it = 0; d_it < ${dim}; d_it++) {
      ${Dtype} value = input[e_idx * ${dim} + d_it];
      value *= kernel_size[d_it] - is_open_spline[d_it];
      frac[d_it] = value - floor(value);
    }

    ${Dtype} grad_out = 0.0;
    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      // Bit d of k_idx selects the lower (0) or upper (1) knot in dim d,
      // exactly as in the forward kernel.
      ${Dtype} a = 1.0;
      int k_rest = k_idx;
      for (int d_it = 0; d_it < ${dim}; d_it++) {
        const int k_idx_mod = k_rest % 2;
        k_rest >>= 1;
        if (d_it == d_idx) {
          // d(1 - frac)/dfrac = -1, d(frac)/dfrac = +1.
          a *= 2 * k_idx_mod - 1;
        } else {
          a *= (1 - k_idx_mod) * (1 - frac[d_it]) + k_idx_mod * frac[d_it];
        }
      }
      grad_out += a * grad_amount[e_idx * ${k_max} + k_idx];
    }

    grad_adj[idx] = grad_out * (kernel_size[d_idx] - is_open_spline[d_idx]);
  }
}
'''


def get_basis_kernel(k_max, K, dim, degree, dtype='float'):
    """Load the forward B-spline basis CUDA kernel for ``degree``.

    Args:
        k_max: Number of basis products per edge (fills ``${k_max}``).
        K: Initial divisor for the flattened weight index (fills ``${K}``).
        dim: Number of pseudo-coordinate dimensions (fills ``${dim}``).
        degree: 3 selects the cubic kernel and 2 the quadratic one. Any other
            value falls back to the linear kernel.
        dtype: CUDA scalar type name substituted for ``${Dtype}``.

    Returns:
        Whatever ``load_kernel`` returns for ``spline_kernel``.
    """
    if degree == 3:
        _spline_kernel = _spline_kernel_cubic
    elif degree == 2:
        _spline_kernel = _spline_kernel_quadratic
    else:
        _spline_kernel = _spline_kernel_linear

    # The one-element CUDA tensor exists only to select the current CUDA
    # device while the kernel is loaded.
    cuda_tensor = torch.FloatTensor([1]).cuda()
    with torch.cuda.device_of(cuda_tensor):
        f = load_kernel(
            'spline_kernel',
            _spline_kernel,
            Dtype=dtype,
            k_max=k_max,
            dim=dim,
            K=K)
    return f


def get_basis_backward_kernel(k_max, K, dim, degree, dtype='float'):
    """Load the B-spline basis backward CUDA kernel for ``degree``.

    Only the linear basis has a backward kernel. Degrees other than 2 and 3
    fall back to it, mirroring ``get_basis_kernel``.

    Args:
        k_max: Number of basis products per edge (fills ``${k_max}``).
        K: Forwarded to ``load_kernel`` as ``K``.
        dim: Number of pseudo-coordinate dimensions (fills ``${dim}``).
        degree: Spline degree.
        dtype: CUDA scalar type name substituted for ``${Dtype}``.

    Returns:
        Whatever ``load_kernel`` returns for ``spline_kernel``.

    Raises:
        NotImplementedError: If ``degree`` is 2 or 3.
    """
    if degree == 3:
        raise NotImplementedError(
            'Backward of the cubic B-spline basis is not implemented.')
    elif degree == 2:
        raise NotImplementedError(
            'Backward of the quadratic B-spline basis is not implemented.')
    else:
        _spline_kernel = _spline_kernel_linear_backward

    # The one-element CUDA tensor exists only to select the current CUDA
    # device while the kernel is loaded.
    cuda_tensor = torch.FloatTensor([1]).cuda()
    with torch.cuda.device_of(cuda_tensor):
        f = load_kernel(
            'spline_kernel',
            _spline_kernel,
            Dtype=dtype,
            k_max=k_max,
            dim=dim,
            K=K)
    return f


class SplineConvGPU(Function):
    """Legacy (instance-style) autograd Function for spline convolution on GPU.

    The forward pass launches two CUDA kernels:

    * ``basis_kernel`` turns the edge pseudo-coordinates ``adj_values`` into
      B-spline basis products (``amount``) and flattened weight indices
      (``index``), both of shape ``(num_edges, k_max)``.
    * ``weighting_kernel`` combines the per-edge ``input`` features with
      ``weight`` (shape ``(K, M_in, M_out)``) using those products.

    When ``bp_to_adj`` is true, ``backward`` also back-propagates into
    ``adj_values`` through ``basis_backward_kernel``. Otherwise the third
    gradient is ``None``.
    """

    def __init__(self, kernel_size, is_open_spline, K, degree, basis_kernel,
                 basis_backward_kernel, weighting_kernel,
                 weighting_backward_kernel, bp_to_adj=False):
        # NOTE: ``K`` is accepted but not stored; ``self.K`` is taken from
        # ``weight.size()`` in ``forward``.
        super(SplineConvGPU, self).__init__()
        self.degree = degree
        self.f_weighting_fw = weighting_kernel
        self.f_weighting_bw = weighting_backward_kernel
        self.kernel_size = kernel_size
        self.is_open_spline = is_open_spline
        self.f_basis_fw = basis_kernel
        self.f_basis_bw = basis_backward_kernel
        self.bp_to_adj = bp_to_adj

    def forward(self, input, weight, adj_values):
        """Return the ``(input.size(0), M_out)`` weighted edge features.

        ``adj_values`` holds one row of pseudo-coordinates per edge. A 1-D
        tensor is treated as a single dimension.
        """
        assert input.is_cuda and weight.is_cuda
        self.K, self.M_in, self.M_out = weight.size()

        if adj_values.dim() < 2:
            adj_values = adj_values.unsqueeze(1)

        if self.bp_to_adj:
            self.save_for_backward(input, weight, adj_values)
        else:
            self.save_for_backward(input, weight)

        # Compute B-spline basis tensor products.
        num_edges, dim = adj_values.size()
        k_max = (self.degree + 1) ** dim
        amount = adj_values.new(num_edges, k_max)
        index = adj_values.new(num_edges, k_max).long()
        num_threads = amount.numel()
        with torch.cuda.device_of(input):
            self.f_basis_fw(
                block=(cuda_num_threads, 1, 1),
                grid=(get_blocks(num_threads), 1, 1),
                args=[
                    adj_values.data_ptr(),
                    amount.data_ptr(),
                    index.data_ptr(),
                    self.kernel_size.data_ptr(),
                    self.is_open_spline.data_ptr(),
                    num_threads,
                ],
                stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        # Weight features.
        output = input.new(input.size(0), self.M_out)
        num_threads = output.numel()
        with torch.cuda.device_of(input):
            self.f_weighting_fw(
                block=(cuda_num_threads, 1, 1),
                grid=(get_blocks(num_threads), 1, 1),
                args=[
                    input.data_ptr(),
                    weight.data_ptr(),
                    output.data_ptr(),
                    amount.data_ptr(),
                    index.data_ptr(),
                    num_threads,
                ],
                stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        # backward reuses the basis products and indices computed above.
        self.amount = amount
        self.index = index

        return output

    def backward(self, grad_output):
        """Return gradients for ``(input, weight, adj_values)``.

        The gradient for ``adj_values`` is ``None`` unless ``bp_to_adj``.
        """
        if self.bp_to_adj:
            input, weight, adj_values = self.saved_tensors
        else:
            input, weight = self.saved_tensors
        amount = self.amount
        index = self.index

        # The weighting backward kernels accumulate with atomicAdd, so every
        # output buffer starts at zero.
        grad_input = grad_output.new(grad_output.size(0), self.M_in).fill_(0)
        grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0)
        # grad_amount is passed in both modes; it is only read back (by the
        # basis backward kernel) when bp_to_adj is true.
        grad_amount = grad_output.new(amount.size(0),
                                      amount.size(1)).fill_(0)

        num_threads = grad_output.numel()
        with torch.cuda.device_of(grad_output):
            self.f_weighting_bw(
                block=(cuda_num_threads, 1, 1),
                grid=(get_blocks(num_threads), 1, 1),
                args=[
                    grad_output.data_ptr(),
                    grad_input.data_ptr(),
                    grad_weight.data_ptr(),
                    grad_amount.data_ptr(),
                    input.data_ptr(),
                    weight.data_ptr(),
                    amount.data_ptr(),
                    index.data_ptr(),
                    num_threads,
                ],
                stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        if not self.bp_to_adj:
            return grad_input, grad_weight, None

        # Back-propagate grad_amount through the B-spline basis into the
        # pseudo-coordinates: one thread per (edge, dimension) pair.
        grad_adj = grad_amount.new(grad_amount.size(0),
                                   self.kernel_size.size(0)).fill_(0)
        num_threads = grad_adj.numel()
        with torch.cuda.device_of(grad_amount):
            self.f_basis_bw(
                block=(cuda_num_threads, 1, 1),
                grid=(get_blocks(num_threads), 1, 1),
                args=[
                    adj_values.data_ptr(),
                    grad_amount.data_ptr(),
                    amount.data_ptr(),
                    grad_adj.data_ptr(),
                    self.kernel_size.data_ptr(),
                    self.is_open_spline.data_ptr(),
                    num_threads,
                ],
                stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        return grad_input, grad_weight, grad_adj