utils.h 11.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_

#include "transformer_engine/transformer_engine.h"

yuguo's avatar
yuguo committed
12
13
14
15
// HIP/AMD builds: every later `__syncwarp` token in this translation unit is rewritten to
// `__syncthreads`.
// NOTE(review): __syncthreads() is a block-wide barrier, whereas the helpers below call
// __syncwarp() from per-warp code. This is only safe if all warps of a block reach the same
// barrier call sites the same number of times -- confirm against the calling kernels.
#ifdef __HIP_PLATFORM_AMD__
#define __syncwarp __syncthreads
#endif

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
namespace transformer_engine {

// Number of lanes the warp-level helpers in this header stride by and shuffle across.
constexpr size_t kThreadsPerWarp = 32;
// Threads per block (CTA): 128 = 4 * kThreadsPerWarp.
constexpr int kThreadsPerBlock =
    128;  // Using 4 warps in 1 CTA, Each warp is responsible for 1 token.
// Small positive constant. Not referenced inside this header; presumably used by the router
// kernels to guard divisions -- verify against the callers.
constexpr float epsilon = 1e-20;

// Returns `a` when `a > b` holds, otherwise `b` (so `b` is returned for equal operands and
// whenever the comparison is false, e.g. if `a` is NaN).
template <typename T>
__device__ inline T max(T a, T b) {
  if (a > b) {
    return a;
  }
  return b;
}

// Returns `a + b`, converted back to T.
template <typename T>
__device__ inline T sum(T a, T b) {
  const T total = a + b;
  return total;
}

33
34
35
36
37
// Selects the binary operation applied by warp_reduce_on_shmem / masked_warp_reduce_on_shmem.
enum ReduceFuncType {
  SUM,  // combine with sum(); the reduction starts from 0
  MAX,  // combine with max(); the reduction starts from -infinity
};

38
// Warp-wide SUM or MAX reduction over `data_ptr[0 .. data_size)`.
//
// Each lane first folds the elements at lane_id, lane_id + kThreadsPerWarp, ... into a private
// accumulator (lane 0 handles the 0-th, 32-th, 64-th, ... element). The per-lane partial
// results are then combined with XOR shuffles, so every lane ends up with the full result.
//
// The accumulation is carried out in double and only narrowed to T on return, so low-precision
// element types (e.g. fp16/bf16) do not lose precision at every step.
//
//   data_ptr   buffer to reduce; only read here (presumably shared memory, per the function
//              name -- the code itself does not require it)
//   data_size  number of elements to reduce
//   type       ReduceFuncType::MAX selects a max reduction; every other value is treated as
//              ReduceFuncType::SUM
//   lane_id    index of the calling thread inside its warp
//
// Returns the reduction converted to T. When data_size <= 0 the result is the identity of the
// operation: 0 for SUM, -infinity for MAX.
//
// The shuffles use the full 0xffffffff mask, so all kThreadsPerWarp lanes must call this
// function together.
template <typename T>
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type,
                                         int lane_id) {
  const bool is_max = (type == ReduceFuncType::MAX);
  const double identity = is_max ? -std::numeric_limits<double>::infinity() : 0.0;
  // Same operand order as max()/sum(): `lhs` is always the running accumulator.
  auto combine = [is_max](double lhs, double rhs) -> double {
    if (is_max) return lhs > rhs ? lhs : rhs;
    return lhs + rhs;
  };

  // Reduce the values owned by this lane.
  // `volatile` is kept from the previous implementation; presumably a compiler workaround --
  // TODO confirm before removing.
  volatile double val = lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : identity;
  for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
    val = combine(val, static_cast<double>(data_ptr[i]));
  }

  // Warp shuffle between threads: offsets 16, 8, 4, 2, 1.
  for (int offset = static_cast<int>(kThreadsPerWarp) / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const double peer =
        __shfl_xor_sync((unsigned long long)0xffffffff, val, offset, kThreadsPerWarp);
#else
    const double peer = __shfl_xor_sync(0xffffffff, val, offset);
#endif
    val = combine(val, peer);
  }
  __syncwarp();
  return T(val);
}

// In-place sigmoid: scores[i] <- 1 / (1 + exp(-scores[i])), evaluated in float.
// The calling lane handles indices lane_id, lane_id + kThreadsPerWarp, ... below data_size.
// No synchronisation is performed here.
template <typename DataType>
__device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, int lane_id) {
  int idx = lane_id;
  while (idx < data_size) {
    const float x = static_cast<float>(scores[idx]);
    const float denom = 1.0f + exp(-x);
    scores[idx] = static_cast<float>(1.0f / denom);
    idx += kThreadsPerWarp;
  }
}

// Warp-wide SUM or MAX reduction over the elements of `data_ptr[0 .. data_size)` whose
// `mask` entry is true; elements with a false mask entry are skipped.
//
// Each lane first folds its own elements (indices lane_id, lane_id + kThreadsPerWarp, ...)
// into a private accumulator. The per-lane partial results are then combined with XOR
// shuffles, so every lane ends up with the full result.
//
// The accumulation is carried out in double and only narrowed to T on return, so low-precision
// element types (e.g. fp16/bf16) do not lose precision at every step.
//
//   data_ptr   buffer to reduce; only read here (presumably shared memory, per the function
//              name -- the code itself does not require it)
//   mask       per-element selection flags; must be non-null and hold at least data_size
//              entries, since it is dereferenced for every index below data_size
//   data_size  number of elements
//   type       ReduceFuncType::MAX selects a max reduction; every other value is treated as
//              ReduceFuncType::SUM
//   lane_id    index of the calling thread inside its warp
//
// Returns the reduction converted to T. When no element is selected the result is the identity
// of the operation: 0 for SUM, -infinity for MAX.
//
// The shuffles use the full 0xffffffff mask, so all kThreadsPerWarp lanes must call this
// function together.
template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
                                                ReduceFuncType type, int lane_id) {
  const bool is_max = (type == ReduceFuncType::MAX);
  const double identity = is_max ? -std::numeric_limits<double>::infinity() : 0.0;
  // Same operand order as max()/sum(): `lhs` is always the running accumulator.
  auto combine = [is_max](double lhs, double rhs) -> double {
    if (is_max) return lhs > rhs ? lhs : rhs;
    return lhs + rhs;
  };

  // Reduce the selected values owned by this lane.
  // `volatile` is kept from the previous implementation; presumably a compiler workaround --
  // TODO confirm before removing.
  volatile double val =
      lane_id < data_size && mask[lane_id] ? static_cast<double>(data_ptr[lane_id]) : identity;
  for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
    if (mask[i]) {
      val = combine(val, static_cast<double>(data_ptr[i]));
    }
  }

  // Warp shuffle between threads: offsets 16, 8, 4, 2, 1.
  for (int offset = static_cast<int>(kThreadsPerWarp) / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const double peer =
        __shfl_xor_sync((unsigned long long)0xffffffff, val, offset, kThreadsPerWarp);
#else
    const double peer = __shfl_xor_sync(0xffffffff, val, offset);
#endif
    val = combine(val, peer);
  }
  __syncwarp();
  return T(val);
}

// In-place sigmoid backward: grad[i] <- grad[i] * y * (1 - y) with y = fwd_output[i].
// The product is evaluated in double and converted to DataType on the store.
// The calling lane handles indices lane_id, lane_id + kThreadsPerWarp, ... below data_size.
// No synchronisation is performed here.
template <typename DataType>
__device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_output,
                                                  int data_size, int lane_id) {
  int idx = lane_id;
  while (idx < data_size) {
    const double upstream = static_cast<double>(grad[idx]);
    const double y = static_cast<double>(fwd_output[idx]);
    grad[idx] = upstream * y * (1 - y);
    idx += kThreadsPerWarp;
  }
}

// In-place softmax backward for one row handled by one warp.
//
//   grad        incoming gradient; overwritten with the gradient w.r.t. the softmax input
//   fwd_output  softmax output saved from the forward pass
//   comp_buf    scratch buffer with at least data_size entries; receives grad * fwd_output
//   mask        optional per-element flags; may be null. When non-null, positions with a
//               false entry contribute 0 to the sum and get a gradient of 0.
//   data_size   number of elements in the row
//   lane_id     index of the calling thread inside its warp
//
// Result for an unmasked position i:
//   grad[i] = fwd_output[i] * (grad[i] - sum_j(grad[j] * fwd_output[j]))
// No barrier is issued after the final store.
template <typename DataType>
__device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output,
                                                  DataType *comp_buf, bool *mask, int data_size,
                                                  int lane_id) {
  const bool has_mask = (mask != nullptr);

  // Stage 1: write output * grad (or 0 for masked-out positions) into comp_buf.
  for (int idx = lane_id; idx < data_size; idx += kThreadsPerWarp) {
    if (has_mask && !mask[idx]) {
      comp_buf[idx] = 0.0f;
      continue;
    }
    comp_buf[idx] = static_cast<float>(grad[idx]) * static_cast<float>(fwd_output[idx]);
  }
  __syncwarp();

  // Stage 2: warp-wide sum of comp_buf.
  float grad_dot_output =
      warp_reduce_on_shmem(comp_buf, data_size, ReduceFuncType::SUM, lane_id);

  // Stage 3: in-place update of grad.
  for (int idx = lane_id; idx < data_size; idx += kThreadsPerWarp) {
    if (has_mask && !mask[idx]) {
      grad[idx] = 0.0f;
      continue;
    }
    grad[idx] =
        static_cast<float>(fwd_output[idx]) * (static_cast<float>(grad[idx]) - grad_dot_output);
  }
}

// In-place softmax over scores[0 .. data_size), computed cooperatively by one warp.
// The row maximum is subtracted before exponentiation. A __syncwarp() follows both the
// exponentiation pass and the final normalisation pass.
template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
  // Step 1: row maximum.
  const float row_max =
      static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id));

  // Step 2: scores[i] <- exp(scores[i] - row_max).
  for (int idx = lane_id; idx < data_size; idx += kThreadsPerWarp) {
    const float shifted = static_cast<float>(scores[idx]) - row_max;
    scores[idx] = static_cast<float>(exp(shifted));
  }
  __syncwarp();

  // Step 3: sum of the exponentials.
  const float exp_total =
      static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id));

  // Step 4: normalise.
  for (int idx = lane_id; idx < data_size; idx += kThreadsPerWarp) {
    scores[idx] = static_cast<float>(scores[idx]) / exp_total;
  }
  __syncwarp();
}

// Iterative top-k selection performed by one warp.
//
// Runs `topk` rounds. In each round the warp finds the largest entry of
// scores[0 .. data_size) whose index was not chosen in an earlier round, and lane 0 stores
// that index in topk_indices[k] and its value in topk_scores[k]. `scores` is never modified;
// "masking" means skipping indices already present in topk_indices.
//
//   scores        candidate values (read only)
//   data_size     number of candidates
//   topk          number of rounds / entries to produce
//   topk_indices  output, at least topk entries; also read back by all lanes in later rounds
//   topk_scores   output, at least topk entries
//   lane_id       index of the calling thread inside its warp
//
// The shuffles use the full 0xffffffff mask, so all lanes of the warp must call together.
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
                                           T *topk_scores, int lane_id) {
  // True when `candidate` was already selected in one of the first `round` rounds
  // (always false for round 0, since the loop body never runs).
  auto already_selected = [&topk_indices](int round, int candidate) {
    for (int prev = 0; prev < round; ++prev) {
      if (topk_indices[prev] == candidate) return true;
    }
    return false;
  };

  // One round per requested entry: find the current maximum and its index, record it in
  // topk_indices so that later rounds skip it.
  for (int k = 0; k < topk; ++k) {
    // Lane-local candidate: out-of-range or already selected positions count as -infinity.
    volatile double best_val = (lane_id >= data_size || already_selected(k, lane_id))
                                   ? -std::numeric_limits<double>::infinity()
                                   : static_cast<double>(scores[lane_id]);
    volatile int best_idx = (lane_id < data_size) ? lane_id : 0;

    // Each lane scans its own elements first
    // (lane 0 handles the 0-th, 32-th, 64-th, 96-th ... element).
    for (int pos = lane_id + kThreadsPerWarp; pos < data_size; pos += kThreadsPerWarp) {
      volatile double candidate_val = already_selected(k, pos)
                                          ? -std::numeric_limits<double>::infinity()
                                          : static_cast<double>(scores[pos]);
      if (candidate_val > best_val) {
        best_val = candidate_val;
        best_idx = pos;
      }
    }

    // Warp shuffle between threads: exchange (value, index) pairs at offsets 16, 8, 4, 2, 1
    // and keep the strictly larger value.
    for (int offset = 16; offset > 0; offset >>= 1) {
#ifdef __HIP_PLATFORM_AMD__
      volatile auto peer_val =
          __shfl_xor_sync((unsigned long long)0xffffffff, best_val, offset, kThreadsPerWarp);
      volatile auto peer_idx =
          __shfl_xor_sync((unsigned long long)0xffffffff, best_idx, offset, kThreadsPerWarp);
#else
      volatile auto peer_val = __shfl_xor_sync(0xffffffff, best_val, offset);
      volatile auto peer_idx = __shfl_xor_sync(0xffffffff, best_idx, offset);
#endif
      if (peer_val > best_val) {
        best_val = peer_val;
        best_idx = peer_idx;
      }
    }

    // Lane 0 publishes the winner of this round.
    if (lane_id == 0) {
      topk_indices[k] = best_idx;
      topk_scores[k] = best_val;
    }
    // Barrier before the next round reads topk_indices[k].
    __syncwarp();
  }
}

// Dispatches on a runtime `dtype` value for router probabilities.
// Expands to a `switch` over DType::kFloat32 / kFloat16 / kBFloat16 in which `type` is aliased
// to float / fp16 / bf16 respectively and the trailing macro arguments are emitted as the body
// of the matching case. Any other dtype reaches NVTE_ERROR("Invalid type.").
// TE currently supports only float32/bf16/fp16 here; float64 probs should be considered in
// the future.
#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \
  switch (dtype) {                                        \
    using namespace transformer_engine;                   \
    case DType::kFloat32: {                               \
      using type = float;                                 \
      { __VA_ARGS__ }                                     \
    } break;                                              \
    case DType::kFloat16: {                               \
      using type = fp16;                                  \
      { __VA_ARGS__ }                                     \
    } break;                                              \
    case DType::kBFloat16: {                              \
      using type = bf16;                                  \
      { __VA_ARGS__ }                                     \
    } break;                                              \
    default:                                              \
      NVTE_ERROR("Invalid type.");                        \
  }

// Dispatches on a runtime `dtype` value for router indices.
// Expands to a `switch` over DType::kInt32 / kInt64 in which `type` is aliased to
// int32_t / int64_t respectively and the trailing macro arguments are emitted as the body of
// the matching case. Any other dtype reaches NVTE_ERROR("Invalid type.").
#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \
  switch (dtype) {                                        \
    using namespace transformer_engine;                   \
    case DType::kInt32: {                                 \
      using type = int32_t;                               \
      { __VA_ARGS__ }                                     \
    } break;                                              \
    case DType::kInt64: {                                 \
      using type = int64_t;                               \
      { __VA_ARGS__ }                                     \
    } break;                                              \
    default:                                              \
      NVTE_ERROR("Invalid type.");                        \
  }
}  // namespace transformer_engine
#endif