// helper_kernel.cu.h
#pragma once
// Range-based grid-stride loop helpers for CUDA kernels, adapted from TensorFlow.
namespace tv {
namespace detail {
// Iterable range over begin, begin + delta, begin + 2 * delta, ... (every
// value strictly below end), meant to be consumed by a range-based for loop
// in device code, e.g. `for (T i : KernelLoop<T>(begin, delta, end)) {...}`.
//
// `delta` must be positive: with delta == 0 the iterator never advances and
// the loop does not terminate.
template <typename T>
class KernelLoop {
  struct Iterator {
    // `end` is the bound of the range this iterator walks; operator++ uses it
    // to clamp instead of stepping past the bound.
    __forceinline__ __device__ Iterator(T index, T delta, T end)
        : index_(index), delta_(delta), end_(end) {}

    __forceinline__ __device__ T operator*() const { return index_; }

    // Advances by delta_, saturating at end_.
    // Stepping to index_ + delta_ unconditionally would overflow T when end_
    // is within delta_ of T's maximum. For signed T that is undefined
    // behaviour. For unsigned T it wraps to a small value, so the loop keeps
    // running and revisits indices.
    // The visited indices are unchanged: index_ + delta_ < end_ holds exactly
    // when end_ - index_ > delta_, and landing on end_ terminates the loop
    // just as any value >= end_ did before.
    // Precondition (always true inside a range-based for): index_ < end_.
    __forceinline__ __device__ Iterator &operator++() {
      if (end_ - index_ > delta_) {
        index_ += delta_;
      } else {
        index_ = end_;
      }
      return *this;
    }

    // An end iterator is tagged by delta_ == 0 and compares equal to every
    // iterator at or past its index.
    // In range-based for loops the right-hand side is always end(), so this
    // optimizes to 'return less'.
    __forceinline__ __device__ bool operator!=(const Iterator &other) const {
      bool greater = index_ > other.index_;
      bool less = index_ < other.index_;
      if (!other.delta_) {
        return less;
      }
      if (!delta_) {
        return greater;
      }
      return less || greater;
    }

   private:
    T index_;
    const T delta_;
    const T end_;
  };

 public:
  __forceinline__ __device__ KernelLoop(T begin, T delta, T end)
      : begin_(begin), delta_(delta), end_(end) {}

  __forceinline__ __device__ Iterator begin() const {
    return Iterator{begin_, delta_, end_};
  }

  // A zero delta tags this iterator as the end sentinel (see operator!=).
  __forceinline__ __device__ Iterator end() const {
    return Iterator{end_, 0, end_};
  }

 private:
  T begin_;
  T delta_;
  T end_;
};

}  // namespace detail
// Helper to visit indices in the range 0 <= i < count using the x-coordinate.
// Usage: for (int i : KernelLoopX(count)) { visit(i); }
//
// Each thread starts at its global x index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances by
// gridDim.x * blockDim.x * NumILP.
// With NumILP > 1 only the stride is widened. The loop bound still applies
// to the base index alone, so the caller presumably handles the NumILP
// sub-indices itself and must bound-check each of them against count.
// NOTE(review): confirm against call sites.
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopX(T count) {
  static_assert(NumILP > 0, "NumILP must be positive");
  // Widen to T before multiplying. The launch built-ins are 32-bit unsigned,
  // so for a 64-bit T the products would otherwise wrap in 32 bits before
  // being converted to T. For T = int the arithmetic is unchanged: it is
  // still carried out in unsigned int.
  const T begin = static_cast<T>(blockIdx.x) * blockDim.x + threadIdx.x;
  const T stride = static_cast<T>(gridDim.x) * blockDim.x * NumILP;
  return detail::KernelLoop<T>(begin, stride, count);
}

// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for (int i : KernelLoopY(count)) { visit(i); }
//
// Each thread starts at its global y index
// (blockIdx.y * blockDim.y + threadIdx.y) and advances by
// gridDim.y * blockDim.y * NumILP.
// With NumILP > 1 only the stride is widened. The loop bound still applies
// to the base index alone, so the caller presumably handles the NumILP
// sub-indices itself and must bound-check each of them against count.
// NOTE(review): confirm against call sites.
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopY(T count) {
  static_assert(NumILP > 0, "NumILP must be positive");
  // Widen to T before multiplying. The launch built-ins are 32-bit unsigned,
  // so for a 64-bit T the products would otherwise wrap in 32 bits before
  // being converted to T. For T = int the arithmetic is unchanged: it is
  // still carried out in unsigned int.
  const T begin = static_cast<T>(blockIdx.y) * blockDim.y + threadIdx.y;
  const T stride = static_cast<T>(gridDim.y) * blockDim.y * NumILP;
  return detail::KernelLoop<T>(begin, stride, count);
}

// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for (int i : KernelLoopZ(count)) { visit(i); }
//
// Each thread starts at its global z index
// (blockIdx.z * blockDim.z + threadIdx.z) and advances by
// gridDim.z * blockDim.z * NumILP.
// With NumILP > 1 only the stride is widened. The loop bound still applies
// to the base index alone, so the caller presumably handles the NumILP
// sub-indices itself and must bound-check each of them against count.
// NOTE(review): confirm against call sites.
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopZ(T count) {
  static_assert(NumILP > 0, "NumILP must be positive");
  // Widen to T before multiplying. The launch built-ins are 32-bit unsigned,
  // so for a 64-bit T the products would otherwise wrap in 32 bits before
  // being converted to T. For T = int the arithmetic is unchanged: it is
  // still carried out in unsigned int.
  const T begin = static_cast<T>(blockIdx.z) * blockDim.z + threadIdx.z;
  const T stride = static_cast<T>(gridDim.z) * blockDim.z * NumILP;
  return detail::KernelLoop<T>(begin, stride, count);
}

}  // namespace tv