// kernel_utils.h
#pragma once
// Adapted from TensorFlow's GPU kernel helpers.
namespace tv {
namespace detail {

// Range object over T for use with range-based for loops in device code:
//   for (T i : KernelLoop<T>(begin, delta, end)) { ... }
// Assuming delta > 0, this visits begin, begin + delta, begin + 2 * delta, ...
// for as long as the value is strictly less than end.
//
// NOTE(review): the iterator advances with an unchecked `+= delta`. If
// end + delta does not fit in T, the final increment can overflow (UB for
// signed T, wrap-around for unsigned T). Callers presumably keep counts well
// below the type maximum; confirm.
template <typename T> class KernelLoop {
  // Forward iterator over the visited indices. An iterator whose step is zero
  // plays the role of the "end" sentinel.
  struct Iterator {
    __forceinline__ __device__ Iterator(T index, T delta)
        : index_(index), delta_(delta) {}

    __forceinline__ __device__ T operator*() const { return index_; }

    __forceinline__ __device__ Iterator &operator++() {
      index_ += delta_;
      return *this;
    }

    // Any position at or beyond a zero-step (end) iterator compares equal to
    // it. The loop therefore stops even when the last step jumps past `end`
    // instead of landing on it. In a range-based for loop the right-hand side
    // is always the end sentinel, so only the first branch is taken.
    __forceinline__ __device__ bool operator!=(const Iterator &rhs) const {
      if (rhs.delta_ == 0) {
        return index_ < rhs.index_;
      }
      if (delta_ == 0) {
        return rhs.index_ < index_;
      }
      return index_ != rhs.index_;
    }

  private:
    T index_;
    const T delta_;
  };

public:
  __forceinline__ __device__ KernelLoop(T begin, T delta, T end)
      : begin_(begin), delta_(delta), end_(end) {}

  // Iterator positioned at the first index, stepping by delta_.
  __forceinline__ __device__ Iterator begin() const {
    return Iterator{begin_, delta_};
  }

  // End sentinel: a zero step marks it as the terminating iterator.
  __forceinline__ __device__ Iterator end() const { return Iterator{end_, 0}; }

private:
  T begin_;
  T delta_;
  T end_;
};

} // namespace detail
// Helper to visit indices in the range 0 <= i < count using the x-coordinate
// of the global thread index.
// Usage: for (int i : KernelLoopX(count)) { visit(i); }
//
// Each thread starts at blockIdx.x * blockDim.x + threadIdx.x and advances by
// gridDim.x * blockDim.x * NumILP per iteration.
// NOTE(review): for NumILP > 1 only every NumILP-th grid-width offset is
// yielded. Callers presumably visit the intermediate offsets themselves;
// confirm against call sites.
//
// Start and stride are computed in T rather than in the 32-bit unsigned type
// of the CUDA built-in variables. Otherwise, for a 64-bit T, products such as
// blockIdx.x * blockDim.x (blockIdx.x can exceed 2^21) would silently wrap
// modulo 2^32 before being widened. For 32-bit T the result is unchanged.
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopX(T count) {
  const T start = static_cast<T>(blockIdx.x) * blockDim.x + threadIdx.x;
  const T stride = static_cast<T>(gridDim.x) * blockDim.x * NumILP;
  return detail::KernelLoop<T>(start, stride, count);
}

// Helper to visit indices in the range 0 <= i < count using the y-coordinate
// of the global thread index.
// Usage: for (int i : KernelLoopY(count)) { visit(i); }
//
// Each thread starts at blockIdx.y * blockDim.y + threadIdx.y and advances by
// gridDim.y * blockDim.y * NumILP per iteration.
// NOTE(review): for NumILP > 1 only every NumILP-th grid-width offset is
// yielded. Callers presumably visit the intermediate offsets themselves;
// confirm against call sites.
//
// Start and stride are computed in T rather than in the 32-bit unsigned type
// of the CUDA built-in variables, so a 64-bit T cannot silently wrap modulo
// 2^32 (e.g. with a large NumILP). For 32-bit T the result is unchanged.
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopY(T count) {
  const T start = static_cast<T>(blockIdx.y) * blockDim.y + threadIdx.y;
  const T stride = static_cast<T>(gridDim.y) * blockDim.y * NumILP;
  return detail::KernelLoop<T>(start, stride, count);
}

// Helper to visit indices in the range 0 <= i < count using the z-coordinate
// of the global thread index.
// Usage: for (int i : KernelLoopZ(count)) { visit(i); }
//
// Each thread starts at blockIdx.z * blockDim.z + threadIdx.z and advances by
// gridDim.z * blockDim.z * NumILP per iteration.
// NOTE(review): for NumILP > 1 only every NumILP-th grid-width offset is
// yielded. Callers presumably visit the intermediate offsets themselves;
// confirm against call sites.
//
// Start and stride are computed in T rather than in the 32-bit unsigned type
// of the CUDA built-in variables, so a 64-bit T cannot silently wrap modulo
// 2^32 (e.g. with a large NumILP). For 32-bit T the result is unchanged.
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopZ(T count) {
  const T start = static_cast<T>(blockIdx.z) * blockDim.z + threadIdx.z;
  const T stride = static_cast<T>(gridDim.z) * blockDim.z * NumILP;
  return detail::KernelLoop<T>(start, stride, count);
}

} // namespace tv