"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "37339d1ade03c1c8754051ccd0efec1810d6eb48"
reduce.h 6.91 KB
Newer Older
1
2
3
#pragma once

#include "common.h"
4
5
#include <cstdint>
#include <type_traits>
6
7
8

namespace tl {

9
10
11
12
13
14
15
16
17
18
19
20
// Accumulator type to use when reducing values of type T.
// The 16-bit floating-point types (half_t, bfloat16_t) are widened to float so
// that accumulation loses less precision; every other type accumulates in
// itself.
template <typename T> struct AccType {
  using type = std::conditional_t<std::is_same_v<T, half_t> ||
                                      std::is_same_v<T, bfloat16_t>,
                                  float, T>;
};

21
// Binary reduction functor: returns x + y.
// The call operator is const so the functor can also be invoked through a
// const object or const reference.
struct SumOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return x + y;
  }
};

// Binary reduction functor: returns the larger operand via cutlass::fast_max.
// The call operator is const so the functor can also be invoked through a
// const object or const reference.
struct MaxOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return cutlass::fast_max(x, y);
  }
};

// Binary reduction functor: returns the smaller operand via cutlass::fast_min.
// The call operator is const so the functor can also be invoked through a
// const object or const reference.
struct MinOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return cutlass::fast_min(x, y);
  }
};

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// Binary reduction functor: returns the bitwise AND x & y.
// The call operator is const so the functor can also be invoked through a
// const object or const reference.
struct BitAndOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return x & y;
  }
};

// Binary reduction functor: returns the bitwise OR x | y.
// The call operator is const so the functor can also be invoked through a
// const object or const reference.
struct BitOrOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return x | y;
  }
};

// Binary reduction functor: returns the bitwise XOR x ^ y.
// The call operator is const so the functor can also be invoked through a
// const object or const reference.
struct BitXorOp {
  template <typename T>
  TL_DEVICE T operator()(T const &x, T const &y) const {
    return x ^ y;
  }
};

57
58
// Butterfly all-reduce over a power-of-two group of `threads` threads.
//
// Each stage combines a thread's value with the value held by the thread at
// xor-distance `offset`, for offset = threads/2, threads/4, ..., scale. After
// the last stage every thread holds the Reducer-combination of the
// threads/scale threads that share its index modulo `scale`.
//
// Template parameters:
//   Reducer       - stateless binary functor (e.g. SumOp, MaxOp); a fresh
//                   Reducer() is constructed for every combine.
//   threads       - group size; a power of two in [2, 1024].
//   scale         - xor-distance of the final stage; must be a positive
//                   divisor of `threads`.
//   thread_offset - subtracted from threadIdx.x to form this thread's slot in
//                   red_buf.
//   all_threads   - thread count passed to the named barriers in run_hopper.
//
// Stages with offset >= 32 (only when threads >= 64) cross warp boundaries
// and exchange values through `red_buf`. Each such stage costs two barriers.
// Stages with offset < 32 use tl::shfl_xor_sync with a full-warp mask, so all
// 32 lanes of each participating warp must execute the call.
//
// red_buf requirements (threads >= 64 only; it is never touched otherwise):
// it must be non-null and readable/writable by every participating thread.
// NOTE(review): indices are (threadIdx.x - thread_offset) and that value xor
// offset. Presumably callers keep threadIdx.x - thread_offset in
// [0, threads), so red_buf needs at least `threads` elements (typically
// shared memory) -- confirm against callers.
template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                    threads == 128 or threads == 64 or threads == 32 or
                    threads == 16 or threads == 8 or threads == 4 or
                    threads == 2,
                "AllReduce: threads must be a power of two in [2, 1024]");
  // `scale > 0` is checked first so a zero scale produces this message rather
  // than a modulo-by-zero error (&& short-circuits in constant evaluation).
  static_assert(scale > 0 and threads % scale == 0,
                "AllReduce: scale must be a positive divisor of threads");

  // Block-wide variant: cross-warp stages synchronize with __syncthreads(),
  // so every thread of the block must call this, even threads outside the
  // reduction group.
  template <typename T> static TL_DEVICE T run(T x, T *red_buf = nullptr) {
    constexpr int offset = threads / 2;
    if constexpr (offset >= 32) {
      // The leading barrier ensures no thread is still reading red_buf from
      // a previous stage (or a previous call) when this write lands.
      __syncthreads();
      red_buf[threadIdx.x - thread_offset] = x;
      // Make every thread's write visible before reading the partner slot.
      __syncthreads();
      x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
    } else {
      // The partner is within the same warp: exchange via a register shuffle.
      x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
    }
    if constexpr (offset == scale) {
      return x;
    } else {
      return AllReduce<Reducer, offset, scale, thread_offset, all_threads>::run(
          x, red_buf);
    }
  }

  // Named-barrier variant: same exchange pattern as run(), but cross-warp
  // stages synchronize with PTX `bar.sync` on barrier ids 1 and 2, counting
  // only `all_threads` threads instead of the whole block.
  // Requirements: all_threads must be a multiple of the warp size (a PTX ISA
  // requirement for bar.sync), and barrier ids 1 and 2 must not be in use by
  // other code at the same time.
  template <typename T>
  static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) {
    constexpr int offset = threads / 2;
    if constexpr (offset >= 32) {
      // Barrier 1: previous readers of red_buf are done before this write.
      asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
      red_buf[threadIdx.x - thread_offset] = x;
      // TODO(lei): maybe we can merge the two bar.sync into one?
      // Barrier 2: all writes are visible before reading the partner slot.
      asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
      x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
    } else {
      x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
    }
    if constexpr (offset == scale) {
      return x;
    } else {
      return AllReduce<Reducer, offset, scale, thread_offset,
                       all_threads>::run_hopper(x, red_buf);
    }
  }
};

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Inclusive prefix sum of src[0, N) into dst[0, N). When `reverse` is true it
// computes the inclusive suffix sum instead (dst[i] = src[i] + ... + src[N-1]).
//
// Only threads with threadIdx.x < SEG do any work; all other threads return
// immediately. The array is processed in chunks of SEG elements. Each chunk
// gets a Hillis-Steele shuffle scan, is offset by the running total of the
// chunks already processed (`carry`), and is then written to dst.
//
// `threads` is only validated here; the scan itself uses the first SEG lanes
// of warp 0. SEG must lie in [1, 32] so those lanes fit in one warp.
// src and dst are __restrict__, so they must not alias.
template <int threads, bool reverse = false> struct CumSum1D {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                    threads == 128 or threads == 64 or threads == 32,
                "CumSum1D: threads must be a power of two in [32, 1024]");
  template <typename T, int SEG = 32>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int N) {
    static_assert(SEG >= 1 and SEG <= 32,
                  "CumSum1D: SEG lanes must fit in a single warp");
    if (N <= 0)
      return;

    // Exactly the lanes that stay past the early return below, i.e. lanes
    // [0, SEG). For SEG == 32 this is the full-warp mask 0xffffffff.
    constexpr unsigned MASK = 0xffffffffu >> (32 - SEG);
    const int tid = threadIdx.x;
    if (tid >= SEG)
      return;
    const int lane = tid; // tid < SEG here, so this equals tid % SEG.

    const int num_segments = (N + SEG - 1) / SEG;
    T carry = (T)0;

    if constexpr (reverse) {
      // Walk chunks from the back; within a chunk, lane i accumulates the
      // values at lanes >= i.
      for (int seg = num_segments - 1; seg >= 0; --seg) {
        const int idx = seg * SEG + lane;
        T val = (idx < N) ? src[idx] : (T)0;

#pragma unroll
        for (int off = 1; off < SEG; off <<= 1) {
          T n = (T)tl::shfl_down_sync(MASK, val, off);
          // Lanes whose source (lane + off) lies past the chunk ignore n.
          if (lane < SEG - off)
            val += n;
        }

        val += carry;

        if (idx < N)
          dst[idx] = val;

        // Lane 0 now holds the suffix total from this chunk onward; broadcast
        // it as the carry for the preceding chunk.
        carry = (T)tl::shfl_sync(MASK, val, 0);
      }
    } else {
      // Walk chunks from the front; within a chunk, lane i accumulates the
      // values at lanes <= i.
      for (int seg = 0; seg < num_segments; ++seg) {
        const int idx = seg * SEG + lane;
        T val = (idx < N) ? src[idx] : (T)0;

#pragma unroll
        for (int off = 1; off < SEG; off <<= 1) {
          T n = (T)tl::shfl_up_sync(MASK, val, off);
          // Lanes below `off` have no source lane and ignore n.
          if (lane >= off)
            val += n;
        }

        val += carry;

        if (idx < N)
          dst[idx] = val;

        // The last lane holds the prefix total through this chunk; broadcast
        // it as the carry for the next chunk.
        carry = (T)tl::shfl_sync(MASK, val, SEG - 1);
      }
    }
  }
};

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
// Inclusive prefix sum along one axis of a row-major H x W array (element
// (r, c) lives at index r * W + c). When `reverse` is true it computes the
// inclusive suffix sum instead.
//
//   Axis == 1     : each of the H rows is scanned along its W columns.
//   any other Axis: each of the W columns is scanned along its H rows.
//
// Each warp scans one line (row or column) at a time in chunks of SEG
// elements, carrying the running total between chunks. The threads / SEG
// warps pick up lines in batches.
// NOTE(review): assumes the block provides `threads` threads with
// threadIdx.x in [0, threads) -- confirm against the launch configuration.
// src and dst are __restrict__, so they must not alias.
//
// Previously this returned T without ever returning a value, which is
// ill-formed. It now returns void, matching CumSum1D.
template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                    threads == 128 or threads == 64 or threads == 32,
                "CumSum2D: threads must be a power of two in [32, 1024]");
  template <typename T, int SEG = 32>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int H, int W) {
    // One line per full warp: the shuffles use a full-warp mask, and the
    // per-warp early exit below relies on all 32 lanes sharing one line.
    static_assert(SEG == 32, "CumSum2D: each line is scanned by one full warp");

    constexpr int TILE_H = threads / SEG; // lines processed per batch
    constexpr unsigned MASK = 0xffffffff;
    // Number of independent lines and the length of each scan. For Axis == 0
    // the lines are the W columns and each is H elements long.
    const int num_lines = (Axis == 1) ? H : W;
    const int scan_len = (Axis == 1) ? W : H;
    const int num_blocks = (num_lines + TILE_H - 1) / TILE_H;
    const int num_segments = (scan_len + SEG - 1) / SEG;
    const int tid = threadIdx.x;
    const int lane = tid % SEG;
    const int row = tid / SEG;

    for (int b = 0; b < num_blocks; ++b) {
      const int gRow = b * TILE_H + row; // line handled by this warp
      // Line indices only grow with b, so this warp has no more work.
      if (gRow >= num_lines)
        return;

      T carry = (T)0;

      if constexpr (reverse) {
        // Walk chunks from the end of the line (suffix scan).
        for (int seg = num_segments - 1; seg >= 0; --seg) {
          const int col = seg * SEG + lane; // position along the scan
          const bool in_range = col < scan_len;
          const int elem = (Axis == 1) ? gRow * W + col : col * W + gRow;

          T val = in_range ? src[elem] : (T)0;

#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = tl::shfl_down_sync(MASK, val, off);
            if (lane < SEG - off)
              val += n;
          }

          val += carry;

          // Guard on the scan position so padding lanes never write past the
          // end of the line.
          if (in_range)
            dst[elem] = val;

          // Lane 0 holds the suffix total from this chunk onward.
          carry = tl::shfl_sync(MASK, val, 0);
        }
      } else {
        // Walk chunks from the start of the line (prefix scan).
        for (int seg = 0; seg < num_segments; ++seg) {
          const int col = seg * SEG + lane; // position along the scan
          const bool in_range = col < scan_len;
          const int elem = (Axis == 1) ? gRow * W + col : col * W + gRow;

          T val = in_range ? src[elem] : (T)0;

#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = tl::shfl_up_sync(MASK, val, off);
            if (lane >= off)
              val += n;
          }

          val += carry;

          if (in_range)
            dst[elem] = val;

          // The last lane holds the prefix total through this chunk.
          carry = tl::shfl_sync(MASK, val, SEG - 1);
        }
      }
    }
  }
};

250
} // namespace tl