"platforms/hip/include/HipFFT3D.h" did not exist on "7901ce206813ab921cad7a86c1fb7021f842d5d2"
common.h 7.04 KB
Newer Older
wangkx1's avatar
init  
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#ifndef FMSDACOMMON
#define FMSDACOMMON
#include <algorithm>
#include <cstdio>
#include <cstring>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THCAtomics.cuh>

#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#ifdef _WIN32
#define uint unsigned int
#endif

constexpr int kWarpSize = 32;
#define opmath_t at::opmath_type<scalar_t>

inline int GET_BLOCKS(const int N, const int num_threads) {
  return (N + num_threads - 1) / num_threads;
}

#define CUDA_KERNEL_LOOP(i, n)                                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);                 \
       i += blockDim.x * gridDim.x)

inline bool check_backward_warpp(int d_stride, int D){
  int n_group_threads = D / d_stride;
  return (n_group_threads <= kWarpSize) && (kWarpSize % n_group_threads == 0);
}

template <typename scalar_t, typename transfer_t, int c_per_thread>
__device__ void ms_deform_attn_im2col_bilinear(
    opmath_t out_reg_array[], const scalar_t *&p_value, const int &height,
    const int &width, const opmath_t &h_px, const opmath_t &w_px,
    const opmath_t &attn, const int &w_stride, const int &base_ptr) {

  const int h_low = floor(h_px);
  const int w_low = floor(w_px);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;
  const opmath_t lh = h_px - h_low;
  const opmath_t lw = w_px - w_low;
  const opmath_t hh = 1 - lh;
  const opmath_t hw = 1 - lw;

  const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

  int idx1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
  int idx3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;

  scalar_t v1_array[c_per_thread] = {0.};
  scalar_t v2_array[c_per_thread] = {0.};
  scalar_t v3_array[c_per_thread] = {0.};
  scalar_t v4_array[c_per_thread] = {0.};

  if (h_low >= 0 && w_low >= 0) {
    auto p1 = p_value + idx1;
    *(transfer_t *)(v1_array) = *(transfer_t *)(p1);
  }

  if (h_low >= 0 && w_high < width) {
    auto p2 = p_value + idx2;
    *(transfer_t *)(v2_array) = *(transfer_t *)(p2);
  }
  if (h_high < height && w_low >= 0) {
    auto p3 = p_value + idx3;
    *(transfer_t *)(v3_array) = *(transfer_t *)(p3);
  }
  if (h_high < height && w_high < width) {
    auto p4 = p_value + idx4;
    *(transfer_t *)(v4_array) = *(transfer_t *)(p4);
  }
#pragma unroll
  for (int i = 0; i < c_per_thread; i++) {
    out_reg_array[i] +=
        (opmath_t)attn *
        (w1 * (opmath_t)v1_array[i] + w2 * (opmath_t)v2_array[i] +
         w3 * (opmath_t)v3_array[i] + w4 * (opmath_t)v4_array[i]);
  }
}

template <typename scalar_t, typename transfer_t, int c_per_thread>
__device__ void ms_deform_attn_col2im_bilinear(
    const scalar_t *&p_value, const int &height, const int &width,
    const opmath_t &h_px, const opmath_t &w_px, const opmath_t &attn,
    const int &w_stride, const int &base_ptr, const opmath_t offset_scale_h,
    const opmath_t offset_scale_w, const scalar_t *&top_grad,
    opmath_t *&grad_im, opmath_t *grad_offset) {

  const int h_low = floor(h_px);
  const int w_low = floor(w_px);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;
  const opmath_t lh = h_px - h_low;
  const opmath_t lw = w_px - w_low;
  const opmath_t hh = 1 - lh;
  const opmath_t hw = 1 - lw;

  const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  scalar_t _top_grad_array[c_per_thread] = {0.};
  *(transfer_t *)(_top_grad_array) = *(transfer_t *)(top_grad);

  opmath_t top_grad_array[c_per_thread] = {0.};
  for (int i = 0; i < c_per_thread; ++i) {
    top_grad_array[i] = (opmath_t)(_top_grad_array[i]);
  }

  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

  int idx1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
  int idx3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;

  scalar_t v1_array[c_per_thread] = {0.};
  scalar_t v2_array[c_per_thread] = {0.};
  scalar_t v3_array[c_per_thread] = {0.};
  scalar_t v4_array[c_per_thread] = {0.};

  opmath_t grad_h_weight[c_per_thread] = {0.};
  opmath_t grad_w_weight[c_per_thread] = {0.};

  if (h_low >= 0 && w_low >= 0) {
    auto p1 = p_value + idx1;
    *(transfer_t *)(v1_array) = *(transfer_t *)(p1);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] -= hw * v1_array[i];
      grad_w_weight[i] -= hh * v1_array[i];
      atomicAdd(grad_im + idx1 + i, top_grad_array[i] * attn * w1);
    }
  }

  if (h_low >= 0 && w_high < width) {
    auto p2 = p_value + idx2;
    *(transfer_t *)(v2_array) = *(transfer_t *)(p2);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] -= lw * v2_array[i];
      grad_w_weight[i] += hh * v2_array[i];
      atomicAdd(grad_im + idx2 + i, top_grad_array[i] * attn * w2);
    }
  }
  if (h_high < height && w_low >= 0) {
    auto p3 = p_value + idx3;
    *(transfer_t *)(v3_array) = *(transfer_t *)(p3);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] += hw * v3_array[i];
      grad_w_weight[i] -= lh * v3_array[i];
      atomicAdd(grad_im + idx3 + i, top_grad_array[i] * attn * w3);
    }
  }
  if (h_high < height && w_high < width) {
    auto p4 = p_value + idx4;
    *(transfer_t *)(v4_array) = *(transfer_t *)(p4);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] += lw * v4_array[i];
      grad_w_weight[i] += lh * v4_array[i];
      atomicAdd(grad_im + idx4 + i, top_grad_array[i] * attn * w4);
    }
  }

  opmath_t _grad_offset_x = 0;
  opmath_t _grad_offset_y = 0;
#pragma unroll
  for (int i = 0; i < c_per_thread; ++i) {
    _grad_offset_x +=
        grad_w_weight[i] * top_grad_array[i]; // channel aware term
    _grad_offset_y +=
        grad_h_weight[i] * top_grad_array[i]; // channel aware term
  }
  _grad_offset_x *= (offset_scale_w * attn); // channel shared term
  _grad_offset_y *= (offset_scale_h * attn); // channel shared term

  *grad_offset = _grad_offset_x;
  *(grad_offset + 1) = _grad_offset_y;

  opmath_t current_val;
  opmath_t _grad_offset_z = 0;
#pragma unroll
  for (int i = 0; i < c_per_thread; i++) {
    current_val = (opmath_t)(w1 * v1_array[i] + w2 * v2_array[i] +
                             w3 * v3_array[i] + w4 * v4_array[i]);
    _grad_offset_z += current_val * top_grad_array[i];
  }
  *(grad_offset + 2) = _grad_offset_z;
}



#endif