# edgewise_spline_weighting_gpu.py
import torch
from torch.autograd import Function

from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
                            kernel_loop, get_blocks)

# CUDA source template for the forward pass.
#
# The `${Dtype}`, `${M_in}`, `${M_out}` and `${k_max}` placeholders are filled
# in from the keyword arguments given to `load_kernel` (see
# `get_forward_kernel`).  `CUDA_KERNEL_LOOP` is presumably provided by the
# `kernel_loop` prefix -- confirm in `utils.cuda`.
#
# Each loop iteration handles one output element `idx`, decoded as
# (e_idx, m_out_idx) = divmod(idx, M_out).  For that element the kernel sums
# `amount[k] * weight[index[k], m_in, m_out] * input[e_idx, m_in]` over the
# `k_max` entries belonging to `e_idx` and over all `M_in` input features,
# where `weight` is addressed as a dense row-major [*, M_in, M_out] buffer and
# `amount` / `index` as dense [*, k_max] buffers.
_edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_forward_kernel(
const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output,
const ${Dtype}* amount, const long* index, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    ${Dtype} result = 0.0;
    ${Dtype} w;
    ${Dtype} f;
    int k;
    ${Dtype} b;
    long c;
    long w_idx;

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      k = e_idx * ${k_max} + k_idx;
      b = amount[k];
      c = index[k];

      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        w_idx = c * ${M_out} * ${M_in} +
                m_in_idx * ${M_out} +
                m_out_idx;

        w = weight[w_idx];
        f = input[e_idx * ${M_in} + m_in_idx];

        result += b * w * f;
      }
    }

    output[idx] = result;
  }
}
'''

# CUDA source template for the backward pass.
#
# Placeholders (`${Dtype}`, `${M_in}`, `${M_out}`, `${k_max}`) are filled in
# from the keyword arguments given to `load_kernel` (see
# `get_backward_kernel`).  `CUDA_KERNEL_LOOP` is presumably provided by the
# `kernel_loop` prefix -- confirm in `utils.cuda`.
#
# Each loop iteration handles one `grad_output` element `idx`, decoded as
# (e_idx, m_out_idx) = divmod(idx, M_out), and scatters its contribution into
# `grad_input[e_idx, m_in]` and `grad_weight[index[k], m_in, m_out]` with
# `atomicAdd`.  Both gradient buffers are only ever accumulated into, so the
# caller has to zero them before the launch.
_edgewise_spline_weighting_backward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_backward_kernel(
const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
const ${Dtype}* input, const ${Dtype}* weight, const ${Dtype}* amount,
const long* index, int num_threads) {

  CUDA_KERNEL_LOOP(idx, num_threads) {

    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    ${Dtype} w;
    ${Dtype} g;
    ${Dtype} f;
    ${Dtype} w_grad;
    int k;
    ${Dtype} b;
    long c;
    long w_idx;

    // The incoming gradient of this output element does not depend on the
    // loop variables below, so it is loaded once.
    g = grad_output[e_idx * ${M_out} + m_out_idx];

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      k = e_idx * ${k_max} + k_idx;
      b = amount[k];
      c = index[k];

      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        w_idx = c * ${M_out} * ${M_in} +
                m_in_idx * ${M_out} +
                m_out_idx;

        w = weight[w_idx];

        // Calculate input gradient.
        atomicAdd(&(grad_input[e_idx * ${M_in} + m_in_idx]), b * w * g);
        // This is inefficient: `reduce_sum` shouldn't be done like this.
        // Looping over `M_out` would be better to avoid the `atomicAdd`.

        // Calculate weight gradient.
        f = input[e_idx * ${M_in} + m_in_idx];
        w_grad = f * b * g;
        atomicAdd(&(grad_weight[w_idx]), w_grad);
        // Not so efficient either, but not avoidable.
      }
    }
  }
}
'''


def get_forward_kernel(M_in, M_out, k_max):
    """Load the forward kernel specialised for the given sizes.

    The values are handed to ``load_kernel`` as template parameters for
    ``_edgewise_spline_weighting_forward_kernel``; ``Dtype`` is fixed to
    ``'float'``.

    Args:
        M_in: Value substituted for ``${M_in}`` in the kernel source.
        M_out: Value substituted for ``${M_out}`` in the kernel source.
        k_max: Value substituted for ``${k_max}`` in the kernel source.

    Returns:
        Whatever ``load_kernel`` returns for
        ``'edgewise_spline_weighting_forward_kernel'``.
    """
    # A throw-away CUDA tensor is only used to select the device context in
    # which the kernel gets loaded.
    cuda_tensor = torch.FloatTensor([1]).cuda()
    with torch.cuda.device_of(cuda_tensor):
        f_fw = load_kernel(
            'edgewise_spline_weighting_forward_kernel',
            _edgewise_spline_weighting_forward_kernel,
            Dtype='float',
            M_in=M_in,
            M_out=M_out,
            k_max=k_max)
    return f_fw


def get_backward_kernel(M_in, M_out, k_max, K):
    """Load the backward kernel specialised for the given sizes.

    The values are handed to ``load_kernel`` as template parameters for
    ``_edgewise_spline_weighting_backward_kernel``; ``Dtype`` is fixed to
    ``'float'``.

    Args:
        M_in: Value substituted for ``${M_in}`` in the kernel source.
        M_out: Value substituted for ``${M_out}`` in the kernel source.
        k_max: Value substituted for ``${k_max}`` in the kernel source.
        K: Forwarded to ``load_kernel`` as ``K``.  The kernel template in
            this file does not reference ``${K}``; the argument is kept so
            existing callers keep working.

    Returns:
        Whatever ``load_kernel`` returns for
        ``'edgewise_spline_weighting_backward_kernel'``.
    """
    # A throw-away CUDA tensor is only used to select the device context in
    # which the kernel gets loaded.
    cuda_tensor = torch.FloatTensor([1]).cuda()
    with torch.cuda.device_of(cuda_tensor):
        f_bw = load_kernel(
            'edgewise_spline_weighting_backward_kernel',
            _edgewise_spline_weighting_backward_kernel,
            Dtype='float',
            M_in=M_in,
            M_out=M_out,
            k_max=k_max,
            K=K)
    return f_bw


class EdgewiseSplineWeightingGPU(Function):
    """Autograd function that launches the spline weighting CUDA kernels.

    The forward and backward kernels are supplied by the caller (``k_fw`` /
    ``k_bw``), presumably the results of ``get_forward_kernel`` and
    ``get_backward_kernel`` -- confirm against the call site.

    The kernels are handed raw device pointers and address every buffer with
    dense row-major offsets, so all tensors are made contiguous before a
    launch.

    NOTE(review): the kernel builders in this file hard-code
    ``Dtype='float'`` and declare ``index`` as ``const long*``; this class
    does not check tensor dtypes -- callers presumably pass float32 data and
    int64 indices.
    """

    def __init__(self, amount, index, K, M_in, M_out, k_fw, k_bw):
        """Store the per-call constants and the kernels.

        Args:
            amount: CUDA tensor passed to both kernels as ``amount``.
            index: CUDA tensor passed to both kernels as ``index``.
            K: Size of the first dimension of the weight gradient.
            M_in: Size of the second dimension of the input gradient and of
                the weight gradient.
            M_out: Size of the second dimension of the output.
            k_fw: Callable used to launch the forward kernel.
            k_bw: Callable used to launch the backward kernel.

        Raises:
            AssertionError: If ``amount`` or ``index`` is not a CUDA tensor.
        """
        super(EdgewiseSplineWeightingGPU, self).__init__()
        assert amount.is_cuda and index.is_cuda
        self.amount = amount
        self.index = index
        self.M_in = M_in
        self.M_out = M_out
        self.K = K
        self.f_fw = k_fw
        self.f_bw = k_bw

    def forward(self, input, weight):
        """Run the forward kernel.

        Args:
            input: CUDA tensor; its first dimension determines the number of
                output rows.
            weight: CUDA tensor passed to the kernel as ``weight``.

        Returns:
            A new tensor of size ``(input.size(0), self.M_out)`` created with
            ``input.new``.

        Raises:
            AssertionError: If ``input`` or ``weight`` is not a CUDA tensor.
        """
        assert input.is_cuda and weight.is_cuda

        self.save_for_backward(input, weight)

        # `data_ptr()` ignores strides, so the kernel must be given densely
        # packed buffers.  The contiguous versions are kept in locals so they
        # stay alive for the launch.
        input = input.contiguous()
        weight = weight.contiguous()
        amount = self.amount.contiguous()
        index = self.index.contiguous()

        output = input.new(input.size(0), self.M_out)
        num_threads = output.numel()

        # Nothing to compute for an empty output; skip the launch instead of
        # requesting an empty grid.
        if num_threads > 0:
            with torch.cuda.device_of(input):
                self.f_fw(
                    block=(cuda_num_threads, 1, 1),
                    grid=(get_blocks(num_threads), 1, 1),
                    args=[
                        input.data_ptr(),
                        weight.data_ptr(),
                        output.data_ptr(),
                        amount.data_ptr(),
                        index.data_ptr(),
                        num_threads
                    ],
                    stream=Stream(
                        ptr=torch.cuda.current_stream().cuda_stream))

        return output

    def backward(self, grad_output):
        """Run the backward kernel.

        Args:
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            A tuple ``(grad_input, grad_weight)`` with sizes
            ``(input.size(0), self.M_in)`` and
            ``(self.K, self.M_in, self.M_out)``.
        """
        input, weight = self.saved_tensors

        # Same reasoning as in `forward`: the kernel reads raw memory, so
        # every buffer has to be dense.  Incoming gradients in particular are
        # not guaranteed to be contiguous.
        grad_output = grad_output.contiguous()
        input = input.contiguous()
        weight = weight.contiguous()
        amount = self.amount.contiguous()
        index = self.index.contiguous()

        # The kernel only accumulates into the gradients, hence the zero fill.
        grad_input = grad_output.new(input.size(0), self.M_in).fill_(0)
        grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0)

        num_threads = grad_output.numel()

        # With no gradient elements both results stay zero; skip the launch
        # instead of requesting an empty grid.
        if num_threads > 0:
            with torch.cuda.device_of(grad_output):
                self.f_bw(
                    block=(cuda_num_threads, 1, 1),
                    grid=(get_blocks(num_threads), 1, 1),
                    args=[
                        grad_output.data_ptr(),
                        grad_input.data_ptr(),
                        grad_weight.data_ptr(),
                        input.data_ptr(),
                        weight.data_ptr(),
                        amount.data_ptr(),
                        index.data_ptr(),
                        num_threads
                    ],
                    stream=Stream(
                        ptr=torch.cuda.current_stream().cuda_stream))

        return grad_input, grad_weight