msi_kernel.cu 20.5 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/types.h>
#include <cassert>

#include <cuda_math_helper.h>
#include <grid_utils.h>
#include <kernel_utils.h>

using namespace math;

// Samples a layered texture with bilinear filtering inside each layer and
// cubic interpolation across layers.
//
//   input : indexed as [layer, channel, row, col]; sizes[0] is the layer
//           count and sizes[2]/sizes[3] are the row/column extents. The
//           channel stride is handed to load4, which reads 4 channels
//           (presumably C == 4; the host wrapper in this file checks it).
//   uvw   : normalized coordinates, x -> columns, y -> rows, z -> layers.
//
// Coordinates are mapped to pixel space with ((c + 1) * size - 1) / 2 and
// clipped with clip_coordinates. Layers floor(z) - 1 .. floor(z) + 2 (each
// clipped to the valid range) are sampled bilinearly; corners that fall
// outside the row/column bounds contribute nothing. The four per-layer
// results are blended by cubic_interp1d using the fractional part of z.
template <typename scalar_t, typename index_t>
__device__ inline typename math::TVec4<scalar_t> msi_sample_bilinear_cubic(
    const TensorInfo<scalar_t, index_t>& input,
    math::TVec3<scalar_t> uvw) {
  typedef typename math::TVec3<scalar_t> vec3_t;
  typedef typename math::TVec4<scalar_t> vec4_t;

  const index_t n_layers = input.sizes[0];
  const index_t height = input.sizes[2];
  const index_t width = input.sizes[3];
  const index_t stride_layer = input.strides[0];
  const index_t stride_channel = input.strides[1];
  const index_t stride_row = input.strides[2];
  const index_t stride_col = input.strides[3];

  const int3 extent = {(int)width, (int)height, (int)n_layers};

  // Unnormalize into pixel space, then clamp into the addressable range.
  vec3_t pix =
      ((uvw + 1.f) * vec3_t({(float)extent.x, (float)extent.y, (float)extent.z}) - 1.f) / 2.f;
  pix.x = safe_downgrade_to_int_range(clip_coordinates(pix.x, extent.x));
  pix.y = safe_downgrade_to_int_range(clip_coordinates(pix.y, extent.y));
  pix.z = safe_downgrade_to_int_range(clip_coordinates(pix.z, extent.z));

  // Integer corner below (x, y, z); the other bilinear corners are +1 in x and/or y.
  const index_t x0 = static_cast<index_t>(::floor(pix.x));
  const index_t y0 = static_cast<index_t>(::floor(pix.y));
  const index_t z0 = static_cast<index_t>(::floor(pix.z));
  const index_t x1 = x0 + 1;
  const index_t y1 = y0 + 1;

  const scalar_t frac_z = pix.z - z0;

  // Bilinear weight of each corner = area of the opposite sub-rectangle.
  const scalar_t w00 = (x1 - pix.x) * (y1 - pix.y);
  const scalar_t w10 = (pix.x - x0) * (y1 - pix.y);
  const scalar_t w01 = (x1 - pix.x) * (pix.y - y0);
  const scalar_t w11 = (pix.x - x0) * (pix.y - y0);

  vec4_t layer_samples[4];
#pragma unroll 4
  for (index_t k = 0; k < 4; ++k) {
    const scalar_t layer_f = clip_coordinates(z0 - 1 + k, extent.z);
    const int layer = static_cast<int>(layer_f);
    const auto layer_base = input.data + layer * stride_layer;

    vec4_t acc = {0, 0, 0, 0};
    // Adds one weighted corner if it lies inside the image.
    auto add_corner = [&](index_t cy, index_t cx, scalar_t weight) {
      if (within_bounds_2d(cy, cx, height, width)) {
        acc = acc + load4(layer_base + cy * stride_row + cx * stride_col, stride_channel) * weight;
      }
    };
    // Fixed accumulation order keeps the floating-point sum reproducible.
    add_corner(y0, x0, w00);
    add_corner(y0, x1, w10);
    add_corner(y1, x0, w01);
    add_corner(y1, x1, w11);

    layer_samples[k] = acc;
  }
  return cubic_interp1d<scalar_t, 4>(
      layer_samples[0], layer_samples[1], layer_samples[2], layer_samples[3], frac_z);
}

// Adjoint of msi_sample_bilinear_cubic: scatters grad_output back onto the
// texels that the forward sampler read.
//
//   grad_input             : gradient buffer indexed like the forward input,
//                            [layer, channel, row, col].
//   grad_output            : gradient of one 4-channel sample.
//   uvw                    : the same normalized coordinate used in forward.
//   grad_input_memory_span : forwarded to safe_add_2d4 (the host passes the
//                            buffer's numel()).
//
// Each texel receives bilinear_weight * grad_output * cubic_weight, with the
// cubic weights taken from get_cubic_upsampling_coefficients at frac(z).
// Accumulation goes through safe_add_2d4; see Note [Passing pointer and
// offset to fastAtomicAdd].
template <typename scalar_t, typename index_t>
__device__ inline void msi_sample_bilinear_cubic_backward(
    const TensorInfo<scalar_t, index_t>& grad_input,
    math::TVec4<scalar_t> grad_output,
    math::TVec3<scalar_t> uvw,
    index_t grad_input_memory_span) {
  typedef typename math::TVec3<scalar_t> vec3_t;

  const index_t stride_layer = grad_input.strides[0];
  const index_t stride_channel = grad_input.strides[1];
  const index_t stride_row = grad_input.strides[2];
  const index_t stride_col = grad_input.strides[3];

  const index_t n_layers = grad_input.sizes[0];
  const index_t height = grad_input.sizes[2];
  const index_t width = grad_input.sizes[3];

  const int3 extent = {(int)width, (int)height, (int)n_layers};

  // Same unnormalize + clip as the forward sampler.
  vec3_t pix =
      ((uvw + 1.f) * vec3_t({(float)extent.x, (float)extent.y, (float)extent.z}) - 1.f) / 2.f;
  pix.x = safe_downgrade_to_int_range(clip_coordinates(pix.x, extent.x));
  pix.y = safe_downgrade_to_int_range(clip_coordinates(pix.y, extent.y));
  pix.z = safe_downgrade_to_int_range(clip_coordinates(pix.z, extent.z));

  const index_t x0 = static_cast<index_t>(::floor(pix.x));
  const index_t y0 = static_cast<index_t>(::floor(pix.y));
  const index_t z0 = static_cast<index_t>(::floor(pix.z));
  const index_t x1 = x0 + 1;
  const index_t y1 = y0 + 1;

  const scalar_t frac_z = pix.z - z0;

  const scalar_t w00 = (x1 - pix.x) * (y1 - pix.y);
  const scalar_t w10 = (pix.x - x0) * (y1 - pix.y);
  const scalar_t w01 = (x1 - pix.x) * (pix.y - y0);
  const scalar_t w11 = (pix.x - x0) * (pix.y - y0);

  scalar_t cubic_w[4];
  get_cubic_upsampling_coefficients<scalar_t>(cubic_w, frac_z);

#pragma unroll 4
  for (index_t k = 0; k < 4; ++k) {
    const scalar_t layer_f = clip_coordinates(z0 - 1 + k, extent.z);
    const int layer = static_cast<int>(layer_f);
    const index_t layer_offset = layer * stride_layer;

    // Bounds-checked accumulation of one corner's share of the gradient.
    auto scatter = [&](index_t cy, index_t cx, scalar_t weight) {
      safe_add_2d4(
          grad_input.data,
          stride_channel,
          cy,
          cx,
          stride_row,
          stride_col,
          height,
          width,
          weight * grad_output * cubic_w[k],
          layer_offset,
          grad_input_memory_span);
    };
    scatter(y0, x0, w00);
    scatter(y0, x1, w10);
    scatter(y1, x0, w01);
    scatter(y1, x1, w11);
  }
}

// Converts a direction vector into normalized equirectangular coordinates:
//   .x = atan2(d.z, d.x) / pi                 (longitude, within [-1, 1])
//   .y = 2 * atan2(d.y, norm(d.x, d.z)) / pi  (latitude,  within [-1, 1])
__device__ __host__ __forceinline__ float2 direction_to_equirectangular(float3 d) {
  constexpr float kInvPi = M_1_PI;

  const float azimuth = atan2f(d.z, d.x);
  const float horizontal_len = math::norm(float2{d.x, d.z});
  const float elevation = atan2f(d.y, horizontal_len);

  return float2{azimuth * kInvPi, 2 * elevation * kInvPi};
}

// Forward ray marcher for a multi-sphere image.
//
// Each iteration of CUDA_KERNEL_LOOP_TYPE handles one ray `index`:
//   * the origin is read from row `index` of ray_o and the direction from
//     row `index` of ray_d (normalized here);
//   * the ray is intersected with texture.sizes[0] * sub_step_count spheres
//     centered at the coordinate origin. Step i uses inverse radius
//     lerp(max_inv_r, min_inv_r, a) with a = (n_steps - 1 - i + 0.5) / n_steps,
//     and the far intersection t = tc + sqrt(det). Spheres the ray misses
//     (det < 0) are skipped;
//   * the hit point is looked up at (equirectangular(hit), 1 - 2a) with
//     msi_sample_bilinear_cubic. Samples with alpha > 0 are composited with
//     density alpha * step_size, where step_size = 1 / n_steps;
//   * marching stops once the transmittance drops below stop_thresh, and the
//     stored log-transmittance is then forced to -1e3.
// Row `index` of rgba_img receives (r, g, b, log_transmittance).
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void msi_forward_kernel(
    const index_t nthreads,
    TensorInfo<float, index_t> ray_o,
    TensorInfo<float, index_t> ray_d,
    TensorInfo<scalar_t, index_t> texture,
    TensorInfo<scalar_t, index_t> rgba_img,
    int sub_step_count,
    double min_inv_r,
    double max_inv_r,
    double stop_thresh) {
  typedef typename math::TVec3<scalar_t> vec3_t;

  const int layer_count = texture.sizes[0];
  const int total_steps = layer_count * sub_step_count;

  const index_t o_row = ray_o.strides[0];
  const index_t o_col = ray_o.strides[1];
  const index_t d_row = ray_d.strides[0];
  const index_t d_col = ray_d.strides[1];
  const index_t out_row = rgba_img.strides[0];
  const index_t out_col = rgba_img.strides[1];

  CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
    const float3 origin = {
        ray_o.data[o_row * index + o_col * 0],
        ray_o.data[o_row * index + o_col * 1],
        ray_o.data[o_row * index + o_col * 2]};

    const float3 dir = normalize(float3(
        {ray_d.data[d_row * index + d_col * 0],
         ray_d.data[d_row * index + d_col * 1],
         ray_d.data[d_row * index + d_col * 2]}));

    // Ray parameter of closest approach to the sphere center, and the squared
    // distance of the ray line from that center.
    const float t_closest = dot(-origin, dir);
    const float dist2 = dot(origin, origin) - t_closest * t_closest;

    const float step_size = 1.0f / float(total_steps);

    float3 color = {0.f, 0.f, 0.f};
    float log_t = 0.f;

    for (int step = 0; step < total_steps; ++step) {
      const float a = (float(total_steps - 1 - step) + 0.5f) / float(total_steps);
      const float inv_r = (1.0 - a) * max_inv_r + a * min_inv_r;
      const float radius = 1.0f / inv_r;

      const float disc = radius * radius - dist2;
      if (disc < 0.0f) {
        continue; // the ray misses this sphere
      }

      const float3 hit = (t_closest + sqrt(disc)) * dir + origin;
      const float3 tex_coord = make_float3(direction_to_equirectangular(hit), 1.f - a * 2.f);

      auto texel = msi_sample_bilinear_cubic(texture, tex_coord);
      vec3_t rgb = {texel.x, texel.y, texel.z};
      const float sigma = texel.w;
      if (!(sigma > 0.0f)) {
        continue; // empty (or NaN) density contributes nothing
      }

      const float density = sigma * step_size;
      const float contribution = __expf(log_t) * (1.f - __expf(-density));
      log_t -= density;

      color = color + contribution * math::max(rgb, {0.f, 0.f, 0.f});

      if (__expf(log_t) < stop_thresh) {
        log_t = -1e3f;
        break;
      }
    }

    auto out = rgba_img.data + out_row * index;
    out[0 * out_col] = color.x;
    out[1 * out_col] = color.y;
    out[2 * out_col] = color.z;
    out[3 * out_col] = log_t;
  }
}

// Backward pass of msi_forward_kernel: accumulates d(loss)/d(texture) into
// texture_grad through msi_sample_bilinear_cubic_backward.
//
// For each ray the forward march is replayed with identical arithmetic, so
// every sample, weight and early-termination decision matches the forward
// pass. Notation for step k: g = rgba_img_grad[index, 0:3],
// C = rgba_img[index, 0:3], w_k = compositing weight, c_k = max(rgb_k, 0),
// T_{k+1} = transmittance after step k, delta = step_size. Then
//   dL/drgb_k   = g * w_k * [rgb_k >= 0]
//   dL/dalpha_k = delta * sum(g * c_k * T_{k+1} - g * sum_{i>k} w_i c_i)
// because d w_k / d alpha_k = delta * T_{k+1} and
// d w_i / d alpha_k = -delta * w_i for every later step i > k.
// The suffix sum g * sum_{i>k} w_i c_i is obtained by starting from g * C
// and subtracting g * w_k c_k at each step.
//
// NOTE(review): rgba_img_grad[index, 3] (gradient of the returned
// log-transmittance) is never read, so no gradient flows through that output
// channel. Confirm callers do not rely on it.
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void msi_backward_kernel(
    const index_t nthreads,
    TensorInfo<float, index_t> ray_o,
    TensorInfo<float, index_t> ray_d,
    TensorInfo<scalar_t, index_t> texture,
    TensorInfo<scalar_t, index_t> texture_grad,
    index_t texture_grad_memory_span,
    TensorInfo<scalar_t, index_t> rgba_img,
    TensorInfo<scalar_t, index_t> rgba_img_grad,
    int sub_step_count,
    double min_inv_r,
    double max_inv_r,
    double stop_thresh) {
  typedef typename math::TVec4<scalar_t> scalar4_t;
  typedef typename math::TVec3<scalar_t> scalar3_t;

  const int n_layers = texture.sizes[0];
  const int n_steps = n_layers * sub_step_count;

  const index_t ray_o_sN = ray_o.strides[0];
  const index_t ray_o_sC = ray_o.strides[1];

  const index_t ray_d_sN = ray_d.strides[0];
  const index_t ray_d_sC = ray_d.strides[1];

  const index_t rgba_img_sN = rgba_img.strides[0];
  const index_t rgba_img_sC = rgba_img.strides[1];
  const index_t rgba_img_grad_sN = rgba_img_grad.strides[0];
  const index_t rgba_img_grad_sC = rgba_img_grad.strides[1];

  CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
    auto rgba_ptr = rgba_img.data + rgba_img_sN * index;
    auto rgba_grad_ptr = rgba_img_grad.data + rgba_img_grad_sN * index;

    scalar3_t out_v_grad = {
        rgba_grad_ptr[0 * rgba_img_grad_sC],
        rgba_grad_ptr[1 * rgba_img_grad_sC],
        rgba_grad_ptr[2 * rgba_img_grad_sC]};
    // g * C: running value of g * sum_{i>=k} w_i c_i, reduced step by step below.
    scalar3_t out_v_acc =
        out_v_grad *
        scalar3_t(
            {rgba_ptr[0 * rgba_img_sC], rgba_ptr[1 * rgba_img_sC], rgba_ptr[2 * rgba_img_sC]});

    const float3 r_o = {
        ray_o.data[ray_o_sN * index + ray_o_sC * 0],
        ray_o.data[ray_o_sN * index + ray_o_sC * 1],
        ray_o.data[ray_o_sN * index + ray_o_sC * 2]};

    const float3 r_d = normalize(float3(
        {ray_d.data[ray_d_sN * index + ray_d_sC * 0],
         ray_d.data[ray_d_sN * index + ray_d_sC * 1],
         ray_d.data[ray_d_sN * index + ray_d_sC * 2]}));

    float tc = dot(-r_o, r_d);
    float h2 = dot(r_o, r_o) - tc * tc;

    float step_size = 1.0f / float(n_steps);

    float log_transmit = 0.f;

    // The march below must mirror msi_forward_kernel exactly so that samples,
    // weights and the termination step agree with the forward output.
    for (int i = 0; i < n_steps; ++i) {
      const float a = (float(n_steps - 1 - i) + 0.5f) / float(n_steps);
      const float inv_r = (1.0 - a) * max_inv_r + a * min_inv_r;

      const float r = 1.0f / inv_r;

      float det = r * r - h2;
      if (det < 0.0f)
        continue;

      float t = tc + sqrt(det);
      float3 pos = t * r_d + r_o;

      const float w = 1.f - a * 2.f;

      const float3 uvw = make_float3(direction_to_equirectangular(pos), w);

      auto sample = msi_sample_bilinear_cubic(texture, uvw);

      scalar3_t rgb = {sample.x, sample.y, sample.z};
      float alpha = sample.w;

      if (alpha > 0.0f) {
        const float pcnt = alpha * step_size;
        const float weight = __expf(log_transmit) * (1.f - __expf(-pcnt));
        log_transmit -= pcnt; // now log T_{k+1}

        auto rgb_01 = math::max(rgb, {0.f, 0.f, 0.f});
        // 1 where max(rgb, 0) passes the value through, 0 where it clamps.
        scalar3_t color_in_01 = scalar3_t(
            {scalar_t(rgb_01.x == rgb.x),
             scalar_t(rgb_01.y == rgb.y),
             scalar_t(rgb_01.z == rgb.z)});

        scalar3_t color_grad = color_in_01 * weight * out_v_grad;

        // g * sum_{i>=k} w_i c_i  ->  g * sum_{i>k} w_i c_i
        out_v_acc -= weight * rgb_01 * out_v_grad;

        // dL/dalpha_k = delta * sum(g * c_k * T_{k+1} - g * sum_{i>k} w_i c_i).
        float alpha_grad =
            step_size * sum(rgb_01 * out_v_grad * __expf(log_transmit) - out_v_acc);

        scalar4_t rgba_grad = make_float4(color_grad, alpha_grad);

        msi_sample_bilinear_cubic_backward(texture_grad, rgba_grad, uvw, texture_grad_memory_span);

        if (__expf(log_transmit) < stop_thresh) {
          log_transmit = -1e3f;
          break;
        }
      }
    }
  }
}

// Host entry point for the MSI forward pass.
//
//   ray_o, ray_d   : float32 CUDA tensors of shape [N, 3] (ray origins and
//                    directions; the kernel normalizes the directions).
//   texture        : CUDA tensor of shape [L, 4, H, W], dtype double, float
//                    or half.
//   sub_step_count : march steps per texture layer; must be > 0.
//   min_inv_r, max_inv_r : inverse-radius range of the spheres; must satisfy
//                    min_inv_r > max_inv_r.
//   stop_thresh    : transmittance threshold in (0, 1) for early termination.
//
// Returns a new [N, 4] tensor with texture's options holding
// (r, g, b, log_transmittance) per ray. Work is enqueued on the current CUDA
// stream of ray_o's device.
__host__ torch::Tensor msi_forward_cuda(
    const torch::Tensor& ray_o,
    const torch::Tensor& ray_d,
    const torch::Tensor& texture,
    int64_t sub_step_count,
    double min_inv_r,
    double max_inv_r,
    double stop_thresh) {
  TORCH_CHECK(
      sub_step_count > 0, "msi(): expected sub_step_count > 0, but got ", sub_step_count);
  TORCH_CHECK(
      stop_thresh > 0 && stop_thresh < 1,
      "msi(): expected 0 < stop_thresh < 1, but got ",
      stop_thresh);

  TORCH_CHECK(
      min_inv_r > max_inv_r,
      "msi(): expected min_inv_r to be greater than max_inv_r, but "
      "got min_inv_r:",
      min_inv_r,
      " and max_inv_r: ",
      max_inv_r);

  TORCH_CHECK(
      ray_o.defined() && ray_d.defined() && texture.defined(),
      "msi(): expected all inputs not be undefined, but "
      "ray_o is ",
      ray_o,
      ", ray_d is ",
      ray_d,
      ", texture is ",
      texture);

  auto ray_o_opt = ray_o.options();
  auto ray_d_opt = ray_d.options();
  auto texture_opt = texture.options();

  auto device = ray_o_opt.device();
  auto tex_dtype = texture_opt.dtype();

  TORCH_CHECK(
      device.is_cuda(), "msi(): expected inputs to be on CUDA device, but got ray_o on ", device);

  const at::cuda::OptionalCUDAGuard device_guard(device);

  TORCH_CHECK(
      device == ray_o_opt.device() && device == ray_d_opt.device() &&
          device == texture_opt.device(),
      "msi(): expected all inputs to be on same device, but input "
      "ray_o is ",
      ray_o_opt.device(),
      ", ray_d is ",
      ray_d_opt.device(),
      ", texture is ",
      texture_opt.device());

  TORCH_CHECK(
      tex_dtype == torch::kFloat64 || tex_dtype == torch::kFloat32 || tex_dtype == torch::kHalf,
      "msi(): expected texture to be of type Double, Float or "
      "Half, but got type ",
      texture_opt.dtype());

  TORCH_CHECK(
      ray_o_opt.dtype() == torch::kFloat32 && ray_d_opt.dtype() == torch::kFloat32,
      "msi(): expected ray_o and ray_d to be of type Float, but "
      "input ray_o is  ",
      ray_o_opt.dtype(),
      " and ray_d is ",
      ray_d_opt.dtype());

  TORCH_CHECK(
      torch::kStrided == ray_o_opt.layout() && torch::kStrided == ray_d_opt.layout() &&
          torch::kStrided == texture_opt.layout(),
      "msi(): expected all inputs to have torch.strided layout, but "
      "ray_o has ",
      ray_o_opt.layout(),
      ", ray_d has ",
      ray_d_opt.layout(),
      ", texture has ",
      texture_opt.layout());

  TORCH_CHECK(
      ray_o.dim() == 2 && ray_d.dim() == 2 && texture.dim() == 4,
      "msi(): expected ray_o and ray_d to have 2 dimensions, "
      "and texture to have 4 dimension, "
      "but got ray_o with size ",
      ray_o.sizes(),
      ", ray_d with size ",
      ray_d.sizes(),
      ", texture with size ",
      texture.sizes());

  TORCH_CHECK(
      ray_o.size(1) == 3 && ray_d.size(1) == 3 && texture.size(1) == 4,
      "msi(): expected ray_o, ray_d to have size 3 along the dimension 1, "
      " and texture to have size 4 along the dimension 1, "
      "but got ray_o with size ",
      ray_o.sizes(),
      ", ray_d with size ",
      ray_d.sizes(),
      ", texture with size ",
      texture.sizes());

  TORCH_CHECK(
      ray_o.size(0) == ray_d.size(0),
      "msi(): expected ray_o, ray_d to have the same size along "
      "the dimension 0, "
      "but got ray_o with size ",
      ray_o.sizes(),
      ", ray_d with size ",
      ray_d.sizes());

  // Keep the full 64-bit ray count; narrowing to int would silently truncate.
  const int64_t N = ray_o.size(0);
  auto rgba_img = torch::empty({N, 4}, texture.options());

  if (N > 0) {
    DISPATCH_FLOAT(texture.scalar_type(), "msi_forward_kernel", [&] {
      // The [N, 4] output can exceed 32-bit indexing even when the [N, 3] ray
      // tensors do not, so it must take part in the fast-path decision.
      if (at::native::canUse32BitIndexMath(ray_o) && at::native::canUse32BitIndexMath(ray_d) &&
          at::native::canUse32BitIndexMath(texture) &&
          at::native::canUse32BitIndexMath(rgba_img)) {
        typedef int index_type;

        msi_forward_kernel<scalar_t, index_type>
            <<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                static_cast<index_type>(N),
                getTensorInfo<float, index_type>(ray_o),
                getTensorInfo<float, index_type>(ray_d),
                getTensorInfo<scalar_t, index_type>(texture),
                getTensorInfo<scalar_t, index_type>(rgba_img),
                (int)sub_step_count,
                min_inv_r,
                max_inv_r,
                stop_thresh);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        typedef int64_t index_type;

        msi_forward_kernel<scalar_t, index_type>
            <<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                static_cast<index_type>(N),
                getTensorInfo<float, index_type>(ray_o),
                getTensorInfo<float, index_type>(ray_d),
                getTensorInfo<scalar_t, index_type>(texture),
                getTensorInfo<scalar_t, index_type>(rgba_img),
                (int)sub_step_count,
                min_inv_r,
                max_inv_r,
                stop_thresh);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  }
  return rgba_img;
}

// Host entry point for the MSI backward pass.
//
// Replays the forward march for every ray and accumulates the gradient of the
// loss with respect to `texture` into a freshly zeroed tensor (same shape and
// options as texture), which is returned. ray_o, ray_d, texture and the scalar
// parameters are expected to be the ones passed to msi_forward_cuda, plus:
//   rgba_img      : the [N, 4] output of msi_forward_cuda;
//   rgba_img_grad : the [N, 4] gradient of the loss w.r.t. rgba_img.
// The kernel reads only channels 0..2 of rgba_img and rgba_img_grad.
// Work is enqueued on the current CUDA stream of ray_o's device.
torch::Tensor msi_backward_cuda(
    const torch::Tensor& rgba_img,
    const torch::Tensor& rgba_img_grad,
    const torch::Tensor& ray_o,
    const torch::Tensor& ray_d,
    const torch::Tensor& texture,
    int64_t sub_step_count,
    double min_inv_r,
    double max_inv_r,
    double stop_thresh) {
  TORCH_CHECK(
      rgba_img.defined() && rgba_img_grad.defined() && ray_o.defined() && ray_d.defined() &&
          texture.defined(),
      "msi_backward(): expected all inputs to be defined");

  const auto device = ray_o.device();
  TORCH_CHECK(
      device.is_cuda(),
      "msi_backward(): expected inputs to be on CUDA device, but got ray_o on ",
      device);

  const at::cuda::OptionalCUDAGuard device_guard(device);

  TORCH_CHECK(
      ray_d.device() == device && texture.device() == device && rgba_img.device() == device &&
          rgba_img_grad.device() == device,
      "msi_backward(): expected all inputs to be on the same device as ray_o (",
      device,
      "), but ray_d is on ",
      ray_d.device(),
      ", texture is on ",
      texture.device(),
      ", rgba_img is on ",
      rgba_img.device(),
      ", rgba_img_grad is on ",
      rgba_img_grad.device());

  // The kernel reads rows [0, N) of every per-ray tensor; mismatched shapes
  // would otherwise become out-of-bounds device reads.
  TORCH_CHECK(
      ray_o.dim() == 2 && ray_d.dim() == 2 && texture.dim() == 4 && rgba_img.dim() == 2 &&
          rgba_img_grad.dim() == 2,
      "msi_backward(): expected ray_o, ray_d, rgba_img and rgba_img_grad to have 2 dimensions "
      "and texture to have 4 dimensions, but got ray_o with size ",
      ray_o.sizes(),
      ", ray_d with size ",
      ray_d.sizes(),
      ", rgba_img with size ",
      rgba_img.sizes(),
      ", rgba_img_grad with size ",
      rgba_img_grad.sizes(),
      ", texture with size ",
      texture.sizes());

  // Keep the full 64-bit ray count; narrowing to int would silently truncate.
  const int64_t N = ray_o.size(0);

  TORCH_CHECK(
      ray_o.size(1) == 3 && ray_d.size(1) == 3 && ray_d.size(0) == N && texture.size(1) == 4 &&
          rgba_img.size(0) == N && rgba_img.size(1) == 4 && rgba_img_grad.size(0) == N &&
          rgba_img_grad.size(1) == 4,
      "msi_backward(): expected ray_o, ray_d of size [N, 3], rgba_img and rgba_img_grad of "
      "size [N, 4] and texture with size 4 along dimension 1, but got ray_o with size ",
      ray_o.sizes(),
      ", ray_d with size ",
      ray_d.sizes(),
      ", rgba_img with size ",
      rgba_img.sizes(),
      ", rgba_img_grad with size ",
      rgba_img_grad.sizes(),
      ", texture with size ",
      texture.sizes());

  auto texture_grad = torch::zeros_like(texture);

  if (N > 0) {
    DISPATCH_FLOAT(texture.scalar_type(), "msi_backward_kernel", [&] {
      if (at::native::canUse32BitIndexMath(ray_o) && at::native::canUse32BitIndexMath(ray_d) &&
          at::native::canUse32BitIndexMath(rgba_img) &&
          at::native::canUse32BitIndexMath(rgba_img_grad) &&
          at::native::canUse32BitIndexMath(texture_grad) &&
          at::native::canUse32BitIndexMath(texture)) {
        typedef int index_type;

        index_type texture_grad_memory_span = texture_grad.numel();
        msi_backward_kernel<scalar_t, index_type>
            <<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                static_cast<index_type>(N),
                getTensorInfo<float, index_type>(ray_o),
                getTensorInfo<float, index_type>(ray_d),
                getTensorInfo<scalar_t, index_type>(texture),
                getTensorInfo<scalar_t, index_type>(texture_grad),
                texture_grad_memory_span,
                getTensorInfo<scalar_t, index_type>(rgba_img),
                getTensorInfo<scalar_t, index_type>(rgba_img_grad),
                (int)sub_step_count,
                min_inv_r,
                max_inv_r,
                stop_thresh);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        typedef int64_t index_type;

        index_type texture_grad_memory_span = texture_grad.numel();
        msi_backward_kernel<scalar_t, index_type>
            <<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                static_cast<index_type>(N),
                getTensorInfo<float, index_type>(ray_o),
                getTensorInfo<float, index_type>(ray_d),
                getTensorInfo<scalar_t, index_type>(texture),
                getTensorInfo<scalar_t, index_type>(texture_grad),
                texture_grad_memory_span,
                getTensorInfo<scalar_t, index_type>(rgba_img),
                getTensorInfo<scalar_t, index_type>(rgba_img_grad),
                (int)sub_step_count,
                min_inv_r,
                max_inv_r,
                stop_thresh);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  }
  return texture_grad;
}