// tin_shift.cpp (1.58 KB), committed by Jintao Lin.
#include "parrots_cpp_helper.hpp"

// Launchers for the TIN shift CUDA kernels. They are only declared here; the
// definitions live in another translation unit (presumably the companion
// CUDA source for this op — confirm). Each receives the CUDA stream that the
// calling Parrots wrapper obtained from its context.

// Forward launcher: called with `input` and `shift` as the op's two inputs and
// `output` as its single output (see tin_shift_forward_cuda).
// NOTE(review): `output` is taken by non-const value, so DArrayLite is
// presumably a lightweight handle to the underlying storage — confirm.
void TINShiftForwardCUDAKernelLauncher(const DArrayLite input,
                                       const DArrayLite shift,
                                       DArrayLite output, cudaStream_t stream);

// Backward launcher: called with `grad_output` and `shift` as inputs and
// `grad_input` as the single output (see tin_shift_backward_cuda).
void TINShiftBackwardCUDAKernelLauncher(const DArrayLite grad_output,
                                        const DArrayLite shift,
                                        DArrayLite grad_input,
                                        cudaStream_t stream);

// CUDA implementation bound to the `tin_shift_forward` extension.
//
//   ins[0]  : input tensor
//   ins[1]  : shift tensor
//   outs[0] : output tensor, filled by TINShiftForwardCUDAKernelLauncher
//
// `attr` is not consulted by this op. Work is issued on the native CUDA
// stream backing `ctx`.
void tin_shift_forward_cuda(CudaContext &ctx, const SSElement &attr,
                            const OperatorBase::in_list_t &ins,
                            OperatorBase::out_list_t &outs) {
  const cudaStream_t cuda_stream =
      getStreamNative<CudaDevice>(ctx.getStream());
  TINShiftForwardCUDAKernelLauncher(/*input=*/ins[0], /*shift=*/ins[1],
                                    /*output=*/outs[0], cuda_stream);
}

// CUDA implementation bound to the `tin_shift_backward` extension.
//
//   ins[0]  : gradient w.r.t. the forward output
//   ins[1]  : shift tensor (the same role as in the forward op)
//   outs[0] : gradient w.r.t. the forward input, filled by
//             TINShiftBackwardCUDAKernelLauncher
//
// `attr` is not consulted by this op. Work is issued on the native CUDA
// stream backing `ctx`.
void tin_shift_backward_cuda(CudaContext &ctx, const SSElement &attr,
                             const OperatorBase::in_list_t &ins,
                             OperatorBase::out_list_t &outs) {
  const cudaStream_t cuda_stream =
      getStreamNative<CudaDevice>(ctx.getStream());
  TINShiftBackwardCUDAKernelLauncher(/*grad_output=*/ins[0], /*shift=*/ins[1],
                                     /*grad_input=*/outs[0], cuda_stream);
}

// Register the `tin_shift_forward` Parrots extension. It is declared with two
// inputs (input, shift) and one output, and dispatches to
// tin_shift_forward_cuda.
PARROTS_EXTENSION_REGISTER(tin_shift_forward)
    .input(2)
    .output(1)
    .apply(tin_shift_forward_cuda)
    .done();

// Register the `tin_shift_backward` Parrots extension. It is declared with two
// inputs (grad_output, shift) and one output (grad_input), and dispatches to
// tin_shift_backward_cuda.
PARROTS_EXTENSION_REGISTER(tin_shift_backward)
    .input(2)
    .output(1)
    .apply(tin_shift_backward_cuda)
    .done();