// reduce.h: reduction functors, a cross-thread all-reduce, and cumulative-sum
// helpers for device code.
#pragma once

#include "common.h"

namespace tl {

// Binary reduction functor: returns x + y.
//
// The call operator is const-qualified because the functor carries no state;
// this lets it be invoked on const instances as well as on temporaries.
struct SumOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return x + y;
  }
};

// Binary reduction functor: returns cutlass::fast_max(x, y).
//
// The call operator is const-qualified because the functor carries no state;
// this lets it be invoked on const instances as well as on temporaries.
struct MaxOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return cutlass::fast_max(x, y);
  }
};

// Binary reduction functor: returns cutlass::fast_min(x, y).
//
// The call operator is const-qualified because the functor carries no state;
// this lets it be invoked on const instances as well as on temporaries.
struct MinOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return cutlass::fast_min(x, y);
  }
};

// Cross-thread all-reduce.
// Butterfly all-reduce: every participating thread ends up with the value
// obtained by combining its own input with the inputs of the partner threads
// reached through XOR exchanges.
//
// Template parameters:
//   Reducer       - default-constructible binary functor, invoked as
//                   Reducer()(a, b).
//   threads       - width of the current reduction step; a power of two in
//                   [2, 1024].
//   scale         - XOR distance at which the recursion stops; the exchange
//                   distances used are threads/2, threads/4, ..., scale.
//   thread_offset - subtracted from threadIdx.x to form the red_buf index.
//   all_threads   - thread count passed to the named barriers in run_hopper.
//
// Each step with an exchange distance of 32 or more goes through red_buf, so
// red_buf must be a valid buffer whenever threads >= 64 (the nullptr default
// is only usable for pure shuffle reductions). The buffer is presumably shared
// memory sized for all participating threads -- confirm against the callers.
template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
  // threads must be one of 2, 4, 8, ..., 1024.
  static_assert(threads >= 2 && threads <= 1024 &&
                (threads & (threads - 1)) == 0);
  static_assert(threads % scale == 0);

  // Reduction that synchronises the buffer exchange with __syncthreads().
  template <typename T> static TL_DEVICE T run(T x, T *red_buf = nullptr) {
    constexpr int stride = threads / 2;
    if constexpr (stride < 32) {
      // Partner is reachable with a register shuffle (full-warp mask).
      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, stride)));
    } else {
      // Partner is at distance >= 32: exchange through red_buf. The first
      // barrier protects values still being read from a previous step, the
      // second one publishes the values written in this step.
      const auto slot = threadIdx.x - thread_offset;
      __syncthreads();
      red_buf[slot] = x;
      __syncthreads();
      x = Reducer()(x, red_buf[slot ^ stride]);
    }
    if constexpr (stride != scale) {
      return AllReduce<Reducer, stride, scale, thread_offset, all_threads>::run(
          x, red_buf);
    } else {
      return x;
    }
  }

  // Same reduction as run(), but the buffer exchange is fenced with the named
  // barriers 1 and 2 (bar.sync over all_threads threads) instead of
  // __syncthreads().
  template <typename T>
  static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) {
    constexpr int stride = threads / 2;
    if constexpr (stride < 32) {
      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, stride)));
    } else {
      const auto slot = threadIdx.x - thread_offset;
      asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
      red_buf[slot] = x;
      // TODO(lei): maybe we can merge the two bar.sync into one?
      asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
      x = Reducer()(x, red_buf[slot ^ stride]);
    }
    if constexpr (stride != scale) {
      return AllReduce<Reducer, stride, scale, thread_offset,
                       all_threads>::run_hopper(x, red_buf);
    } else {
      return x;
    }
  }
};

// Cumulative-sum (inclusive prefix-scan) helpers.
// Inclusive cumulative sum of the N elements of src, written to dst.
//
// Only the threads with threadIdx.x < SEG take part; all other threads return
// immediately. The array is processed in chunks of SEG elements: each chunk is
// scanned with warp shuffles and the running total is carried into the next
// chunk.
//
// Template parameters:
//   threads - block size the caller was compiled for (power of two in
//             [32, 1024]); only validated here.
//   reverse - false: dst[i] = src[0] + ... + src[i]
//             true : dst[i] = src[i] + ... + src[N - 1]
//   T       - element type; must be usable with the __shfl_*_sync intrinsics
//             and constructible from 0.
//   SEG     - chunk width, 1..32 (the scan lives inside a single warp).
//
// Does nothing when N <= 0.
template <int threads, bool reverse = false> struct CumSum1D {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                threads == 128 or threads == 64 or threads == 32);
  template <typename T, int SEG = 32>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int N) {
    // Lane ids are derived from threadIdx.x, so a chunk wider than one warp
    // would silently produce wrong sums.
    static_assert(SEG > 0 && SEG <= 32,
                  "CumSum1D: SEG must be in [1, 32] (single-warp scan)");

    if (N <= 0)
      return;

    constexpr unsigned MASK = 0xffffffff;
    const int tid = threadIdx.x;
    if (tid >= SEG)
      return;
    const int lane = tid; // tid < SEG here, so this equals tid % SEG

    const int num_segments = (N + SEG - 1) / SEG;
    T carry = (T)0; // running total of the chunks already processed

    if constexpr (reverse) {
      // Walk the chunks back to front; inside a chunk build suffix sums.
      for (int seg = num_segments - 1; seg >= 0; --seg) {
        const int idx = seg * SEG + lane;
        T val = (idx < N) ? src[idx] : (T)0; // pad the tail with zeros

#pragma unroll
        for (int off = 1; off < SEG; off <<= 1) {
          T n = (T)__shfl_down_sync(MASK, val, off);
          if (lane < SEG - off)
            val += n;
        }

        val += carry;

        if (idx < N)
          dst[idx] = val;

        // Lane 0 holds the chunk total plus the old carry; the shuffle already
        // delivers it to every lane, so no second broadcast is needed.
        carry = (T)__shfl_sync(MASK, val, 0);
      }
    } else {
      // Walk the chunks front to back; inside a chunk build prefix sums.
      for (int seg = 0; seg < num_segments; ++seg) {
        const int idx = seg * SEG + lane;
        T val = (idx < N) ? src[idx] : (T)0; // pad the tail with zeros

#pragma unroll
        for (int off = 1; off < SEG; off <<= 1) {
          T n = (T)__shfl_up_sync(MASK, val, off);
          if (lane >= off)
            val += n;
        }

        val += carry;

        if (idx < N)
          dst[idx] = val;

        // The last lane holds the chunk total plus the old carry; the shuffle
        // already delivers it to every lane.
        carry = (T)__shfl_sync(MASK, val, SEG - 1);
      }
    }
  }
};

// Two-dimensional cumulative sum.
// Inclusive cumulative sum along one axis of a row-major H x W matrix
// (element (r, c) lives at index r * W + c), read from src and written to dst.
//
// Template parameters:
//   threads - number of threads assumed to call run() (power of two in
//             [32, 1024]); threads / SEG scan lines are processed per pass.
//   Axis    - 1: scan along each row (over the W columns);
//             any other value: scan along each column (over the H rows).
//   reverse - false: sums run from the first element of a line to the current
//             one; true: from the current element to the last one.
//   T       - element type; must be usable with the __shfl_*_sync intrinsics
//             and constructible from 0.
//   SEG     - chunk width; must be 32 because each scan line is handled by
//             one full warp.
//
// Each warp (threadIdx.x / 32) owns one scan line per pass and walks it in
// chunks of SEG elements, carrying the running total between chunks.
//
// The return value carries no information: T(0) is always returned. The
// non-void return type is kept only so existing call sites keep compiling.
template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                threads == 128 or threads == 64 or threads == 32);
  template <typename T, int SEG = 32>
  static TL_DEVICE T run(const T *__restrict__ src, T *__restrict__ dst, int H,
                         int W) {
    // The lane/warp decomposition and the full-warp shuffles below are only
    // consistent with the tile maths when a chunk is exactly one warp wide.
    static_assert(SEG == 32,
                  "CumSum2D: SEG must be 32 (one warp per scan line)");

    constexpr int TILE_H = threads / SEG; // scan lines handled per pass
    constexpr unsigned MASK = 0xffffffff;
    constexpr bool scan_along_w = (Axis == 1);

    // Number of independent lines and the length of each of them. For the
    // column scan these are W and H respectively, not H and W.
    const int num_lines = scan_along_w ? H : W;
    const int scan_len = scan_along_w ? W : H;
    const int num_segments = (scan_len + SEG - 1) / SEG;
    const int num_blocks = (num_lines + TILE_H - 1) / TILE_H;

    const int tid = threadIdx.x;
    const int lane = tid % SEG;
    const int row = tid / SEG; // warp index inside the block

    for (int b = 0; b < num_blocks; ++b) {
      const int line = b * TILE_H + row;
      // Lines only grow with b, so nothing is left for this warp. The whole
      // warp shares `line`, hence it leaves together.
      if (line >= num_lines)
        break;

      T carry = (T)0; // running total of the chunks already processed

      if constexpr (reverse) {
        // Walk the chunks back to front; inside a chunk build suffix sums.
        for (int seg = num_segments - 1; seg >= 0; --seg) {
          const int pos = seg * SEG + lane;
          const bool in_range = pos < scan_len;
          const int index = scan_along_w ? line * W + pos : pos * W + line;

          T val = in_range ? src[index] : (T)0; // pad the tail with zeros

#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = (T)__shfl_down_sync(MASK, val, off);
            if (lane < SEG - off)
              val += n;
          }

          val += carry;

          // Guard on the scan position so padding lanes never store.
          if (in_range)
            dst[index] = val;

          // Lane 0 holds the chunk total plus the old carry; the shuffle
          // already delivers it to every lane.
          carry = (T)__shfl_sync(MASK, val, 0);
        }
      } else {
        // Walk the chunks front to back; inside a chunk build prefix sums.
        for (int seg = 0; seg < num_segments; ++seg) {
          const int pos = seg * SEG + lane;
          const bool in_range = pos < scan_len;
          const int index = scan_along_w ? line * W + pos : pos * W + line;

          T val = in_range ? src[index] : (T)0; // pad the tail with zeros

#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = (T)__shfl_up_sync(MASK, val, off);
            if (lane >= off)
              val += n;
          }

          val += carry;

          // Guard on the scan position so padding lanes never store.
          if (in_range)
            dst[index] = val;

          // The last lane holds the chunk total plus the old carry; the
          // shuffle already delivers it to every lane.
          carry = (T)__shfl_sync(MASK, val, SEG - 1);
        }
      }
    }

    // Defined placeholder result; the function used to fall off the end.
    return (T)0;
  }
};

} // namespace tl