/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include <cuda_runtime.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>

#include <algorithm>
#include <string>
#include <vector>

#include "../common.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
Przemek Tredak's avatar
Przemek Tredak committed
17

18
namespace transformer_engine {
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
namespace {

// String with RTC kernel implementation
#include "string_code_transpose_rtc_cast_transpose_cu.h"

// Hard-coded kernel parameters
// Arithmetic type for the scale factor, amax, no-op flag and per-element math.
using CType = float;
// Warps per CUDA block (each block of the general kernel handles one tile).
constexpr size_t warps_per_tile = 4;
constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile;
// The general kernel walks each tile in THREADS_PER_WARP / warps_per_tile
// iterations per warp; the division must be exact or part of every tile is
// silently skipped.
static_assert(THREADS_PER_WARP % warps_per_tile == 0,
              "THREADS_PER_WARP must be a multiple of warps_per_tile");

/* Performance heuristics for optimized kernel parameters
 *
 * One instance describes a candidate (load_size, store_size) pair for the
 * runtime-compiled kernel. The constructor decides whether the pair evenly
 * tiles a num_rows x row_length input and, if it does, fills in a rough
 * performance model. operator< orders candidates by estimated cost and puts
 * every invalid candidate after every valid one, so std::min_element over a
 * list of candidates yields the cheapest valid one (if any is valid).
 */
struct KernelConfig {
  /** Vector load size in bytes */
  size_t load_size = 0;
  /** Vector store size in bytes to transposed output */
  size_t store_size = 0;

  /* Whether config is valid */
  bool valid = false;
  /* Number of CUDA blocks */
  size_t num_blocks = 0;

  /* Number of active SMs (>= 1 whenever the config is valid) */
  size_t active_sm_count = 0;
  /* Elements per L1 cache load */
  size_t elements_per_load = 0;
  /* Elements per L1 cache store to cast output */
  size_t elements_per_store_c = 0;
  /* Elements per L1 cache store to transposed output */
  size_t elements_per_store_t = 0;

  /* Evaluate one candidate configuration.
   *
   * row_length, num_rows:    input dimensions in elements
   * itype_size, otype_size:  input/output element sizes in bytes
   * load_size_, store_size_: candidate vector sizes in bytes
   *
   * `valid` stays false when a vector size is zero, the vector or cache-line
   * size is not a whole multiple of the element size, or the resulting tile
   * does not evenly divide the tensor.
   */
  KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t otype_size,
               size_t load_size_, size_t store_size_)
      : load_size{load_size_}, store_size{store_size_} {
    // Check that tiles are correctly aligned. A zero vector size would give a
    // zero tile size and a modulo-by-zero below.
    constexpr size_t cache_line_size = 128;
    if (load_size == 0 || store_size == 0 || load_size % itype_size != 0 ||
        store_size % otype_size != 0 || cache_line_size % itype_size != 0 ||
        cache_line_size % otype_size != 0) {
      return;
    }
    const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size;
    const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size;
    valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
    if (!valid) {
      return;
    }

    // Number of CUDA blocks (one per tile)
    num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements);

    // Parameters for performance model
    constexpr size_t warps_per_sm = 16;  // Rough estimate for saturated SMs
    // Clamp to at least one SM: an empty tensor yields num_blocks == 0, and a
    // zero active_sm_count would make operator< divide by zero.
    active_sm_count = std::max<size_t>(
        1, std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
                    static_cast<size_t>(cuda::sm_count())));
    elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) / itype_size);
    elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) / otype_size);
    elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) / otype_size);
  }

  /* Compare by estimated cost; valid configs always compare less than invalid ones */
  bool operator<(const KernelConfig &other) const {
    if (this->valid && other.valid) {
      // cost ~ (1/elements_per_load
      //         + 1/elements_per_store_c
      //         + 1/elements_per_store_t) / active_sms
      // Note: Integer arithmetic ensures stable ordering. Multiplying by the
      // product of all factors makes the three reciprocals exact integers.
      const auto &l1 = this->elements_per_load;
      const auto &sc1 = this->elements_per_store_c;
      const auto &st1 = this->elements_per_store_t;
      const auto &p1 = this->active_sm_count;
      const auto &l2 = other.elements_per_load;
      const auto &sc2 = other.elements_per_store_c;
      const auto &st2 = other.elements_per_store_t;
      const auto &p2 = other.active_sm_count;
      const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2;
      const auto cost1 = (scale / l1 + scale / sc1 + scale / st1) / p1;
      const auto cost2 = (scale / l2 + scale / sc2 + scale / st2) / p2;
      return cost1 < cost2;
    } else {
      return this->valid && !other.valid;
    }
  }
};
Przemek Tredak's avatar
Przemek Tredak committed
103

104
/* Statically-compiled fallback cast + transpose kernel.
 *
 * For a row-major num_rows x row_length `input` it writes
 *   output_c[r * row_length + c] = OType(input[r * row_length + c] * scale)
 *   output_t[c * num_rows + r]   = the same cast value
 * where scale = *scale_ptr (or 1 if scale_ptr is null). If amax_ptr is
 * non-null, max |input| over the block's tile (before scaling) is merged into
 * *amax_ptr with atomicMaxFloat.
 *
 * Launch layout: 1-D grid, block_size threads per block. Each block owns one
 * (THREADS_PER_WARP * nvec_out) x (THREADS_PER_WARP * nvec_in) tile; tiles are
 * numbered row-tile-fastest (tile_id_m = blockIdx.x % num_tiles_m). Reads and
 * writes are bounds-checked, so the dimensions need not be tile multiples.
 * Uses a static __shared__ buffer of THREADS_PER_WARP x (THREADS_PER_WARP + 1)
 * OVecT elements.
 * If noop is non-null and noop[0] == 1, every block returns immediately.
 */
template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size)
    cast_transpose_general_kernel(const IType *__restrict__ const input,
                                  const CType *__restrict__ const noop,
                                  OType *__restrict__ const output_c,
                                  OType *__restrict__ const output_t,
                                  const CType *__restrict__ const scale_ptr,
                                  CType *__restrict__ const amax_ptr, const size_t row_length,
                                  const size_t num_rows) {
  // The flag is read identically by every thread, so the whole block exits
  // together and no later __syncthreads() is left half-reached.
  if (noop != nullptr && noop[0] == 1.0f) return;

  // Vectorized load/store sizes (elements per subtile side)
  constexpr size_t nvec_in = load_size / sizeof(IType);
  constexpr size_t nvec_out = store_size / sizeof(OType);
  using OVecT = Vec<OType, nvec_out>;

  // Thread indices
  // Note: Block is interpreted as a warp_size x num_warps grid
  constexpr size_t bdimx = THREADS_PER_WARP;
  constexpr size_t bdimy = warps_per_tile;
  const size_t tid = threadIdx.x;
  const size_t tidx = tid % bdimx;
  const size_t tidy = tid / bdimx;
  const size_t bid = blockIdx.x;

  // Input tensors are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
  constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out;
  constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in;

  // Position of tile within tensor
  const size_t num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
  const size_t tile_id_m = bid % num_tiles_m;
  const size_t tile_id_n = bid / num_tiles_m;
  const size_t tile_row = tile_id_m * tile_dim_m;
  const size_t tile_col = tile_id_n * tile_dim_n;

  // Number of nvec_out x nvec_in subtiles for each thread to
  // load/store
  constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile;

  // FP8 factors
  const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
  CType amax = 0;

  // Load input and store to registers
  // Note: Each thread loads num_iterations subtiles, computes amax,
  // casts type, and transposes in registers.
  OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
  for (size_t iter = 0; iter < num_iterations; ++iter) {
    const size_t i1 = tidy + iter * bdimy;
    const size_t j1 = tidx;
#pragma unroll
    for (size_t i2 = 0; i2 < nvec_out; ++i2) {
      const size_t row = tile_row + i1 * nvec_out + i2;
      const size_t col = tile_col + j1 * nvec_in;
      if (row < num_rows) {
#pragma unroll
        for (size_t j2 = 0; j2 < nvec_in; ++j2) {
          if (col + j2 < row_length) {
            const CType in = input[row * row_length + col + j2];
            const OType out = OType(in * scale);
            __builtin_assume(amax >= 0);
            amax = fmaxf(fabsf(in), amax);
            output_c[row * row_length + col + j2] = out;
            local_output_t[j2][iter].data.elt[i2] = out;
          }
        }
      }
    }
  }

  // Copy transposed output from registers to global memory, one input column
  // (j2) of every subtile at a time, staged through shared memory. The +1
  // column of padding offsets each row of the buffer.
  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
  for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidy + iter * bdimy;
      const size_t j1 = tidx;
      shared_output_t[j1][i1] = local_output_t[j2][iter];
    }
    __syncthreads();
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidx;
      const size_t j1 = tidy + iter * bdimy;
      const size_t row = tile_row + i1 * nvec_out;
      const size_t col = tile_col + j1 * nvec_in + j2;
      // Same bounds as the load loop, so entries never filled there are never stored.
      if (col < row_length) {
#pragma unroll
        for (size_t i2 = 0; i2 < nvec_out; ++i2) {
          if (row + i2 < num_rows) {
            output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
          }
        }
      }
    }
    // Buffer is overwritten by the next j2 iteration.
    __syncthreads();
  }

  // Reduce amax over block
  if (amax_ptr != nullptr) {
    amax = reduce_max<warps_per_tile>(amax, tidy);
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax_ptr, amax);
    }
  }
}

216
217
}  // namespace

218
219
/* Cast `input` to the output dtype and write both the cast result and its transpose.
 *
 * input:              2-D tensor of shape (num_rows, row_length).
 * noop:               optional one-element FP32 flag tensor. When its data
 *                     pointer is non-null it is forwarded to the kernel; the
 *                     general kernel skips all work if the flag equals 1
 *                     (the RTC kernel presumably does the same — it receives
 *                     the same pointer).
 * cast_output_:       output of shape (num_rows, row_length).
 * transposed_output_: output of shape (row_length, num_rows). It must have the
 *                     same dtype as cast_output_ and share its scale/amax
 *                     tensors.
 * stream:             CUDA stream the kernel is enqueued on.
 *
 * Uses the tuned runtime-compiled kernel when RTC is enabled and both
 * dimensions are multiples of THREADS_PER_WARP. Otherwise it falls back to
 * cast_transpose_general_kernel. Inputs with a zero dimension are a no-op.
 */
void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_,
                    Tensor *transposed_output_, cudaStream_t stream) {
  Tensor &cast_output = *cast_output_;
  Tensor &transposed_output = *transposed_output_;

  // Check no-op flag
  if (noop.data.dptr != nullptr) {
    size_t numel = 1;
    for (const auto &dim : noop.data.shape) {
      numel *= dim;
    }
    NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, ".");
    NVTE_CHECK(noop.data.dtype == DType::kFloat32, "Noop tensor must have FP32 dtype.");
  }

  // Check tensor dims
  CheckInputTensor(input, "cast_transpose_input");
  CheckOutputTensor(cast_output, "cast_output");
  CheckOutputTensor(transposed_output, "transposed_output");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions.");
  NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions.");
  const size_t row_length = input.data.shape[1];
  const size_t num_rows = input.data.shape[0];
  NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output.");
  NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output.");
  NVTE_CHECK(transposed_output.data.shape[0] == row_length,
             "Wrong dimension of transposed output.");
  NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output.");

  // Check tensor pointers
  NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
  NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated.");
  NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated.");
  NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype,
             "Cast and transposed output types must match.");
  NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr,
             "Cast and transposed outputs need to share amax tensor.");
  NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
             "Cast and transposed outputs need to share scale tensor.");

  // Nothing to do for an empty tensor. Without this early return, the static
  // path would launch a zero-block grid (an invalid launch configuration), and
  // the RTC path would feed a zero-block problem into the config heuristics.
  if (row_length == 0 || num_rows == 0) {
    return;
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          cast_output.data.dtype, OutputType,
          constexpr const char *itype_name = TypeInfo<InputType>::name;
          constexpr const char *otype_name = TypeInfo<OutputType>::name;
          constexpr size_t itype_size = sizeof(InputType);
          constexpr size_t otype_size = sizeof(OutputType);

          // Choose between runtime-compiled or statically-compiled kernel
          const bool aligned =
              (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
          if (aligned && rtc::is_enabled()) {  // Runtime-compiled tuned kernel
            // Pick kernel config: the cheapest valid (load_size, store_size) pair
            std::vector<KernelConfig> kernel_configs;
            kernel_configs.reserve(16);
            auto add_config = [&](size_t load_size, size_t store_size) {
              kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size,
                                          store_size);
            };
            add_config(8, 8);
            add_config(4, 8);
            add_config(8, 4);
            add_config(4, 4);
            add_config(2, 8);
            add_config(8, 2);
            add_config(2, 4);
            add_config(4, 2);
            add_config(2, 2);
            add_config(1, 8);
            add_config(8, 1);
            add_config(1, 4);
            add_config(4, 1);
            add_config(1, 2);
            add_config(2, 1);
            add_config(1, 1);
            const auto &kernel_config =
                *std::min_element(kernel_configs.begin(), kernel_configs.end());
            NVTE_CHECK(kernel_config.valid, "invalid kernel config");
            const size_t load_size = kernel_config.load_size;
            const size_t store_size = kernel_config.store_size;
            const size_t num_blocks = kernel_config.num_blocks;

            // Compile NVRTC kernel if needed and launch
            auto &rtc_manager = rtc::KernelManager::instance();
            const std::string kernel_label = concat_strings(
                "cast_transpose"
                ",itype=",
                itype_name, ",otype=", otype_name, ",load_size=", load_size,
                ",store_size=", store_size);
            if (!rtc_manager.is_compiled(kernel_label)) {
              std::string code = string_code_transpose_rtc_cast_transpose_cu;
              code = regex_replace(code, "__ITYPE__", itype_name);
              code = regex_replace(code, "__OTYPE__", otype_name);
              code = regex_replace(code, "__LOAD_SIZE__", load_size);
              code = regex_replace(code, "__STORE_SIZE__", store_size);
              code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
              code = regex_replace(code, "__BLOCK_SIZE__", block_size);
              rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code,
                                  "transformer_engine/common/transpose/rtc/cast_transpose.cu");
            }
            rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
                               static_cast<const InputType *>(input.data.dptr),
                               reinterpret_cast<const CType *>(noop.data.dptr),
                               static_cast<OutputType *>(cast_output.data.dptr),
                               static_cast<OutputType *>(transposed_output.data.dptr),
                               static_cast<const CType *>(cast_output.scale.dptr),
                               static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
          } else {  // Statically-compiled general kernel
            constexpr size_t load_size = 4;
            constexpr size_t store_size = 4;
            constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP;
            constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
            // size_t rather than int: avoid a narrowing conversion for large tensors
            const size_t num_blocks =
                DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size);
            cast_transpose_general_kernel<load_size, store_size, InputType, OutputType>
                <<<num_blocks, block_size, 0, stream>>>(
                    static_cast<const InputType *>(input.data.dptr),
                    reinterpret_cast<const CType *>(noop.data.dptr),
                    static_cast<OutputType *>(cast_output.data.dptr),
                    static_cast<OutputType *>(transposed_output.data.dptr),
                    static_cast<const CType *>(cast_output.scale.dptr),
                    static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
            // Kernel launches do not return errors; surface bad launch configs here.
            const cudaError_t launch_status = cudaGetLastError();
            NVTE_CHECK(launch_status == cudaSuccess,
                       "Failed to launch cast_transpose_general_kernel: ",
                       cudaGetErrorString(launch_status));
          });  // NOLINT(*)
  );           // NOLINT(*)
}

}  // namespace transformer_engine

349
350
void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
                         NVTETensor transposed_output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose);
  using namespace transformer_engine;
  // C API entry point without a no-op flag. A value-initialized Tensor is
  // passed as the flag; cast_transpose only uses the flag when its data
  // pointer is non-null (presumably null for a fresh Tensor — confirm in Tensor).
  const Tensor absent_noop_flag = Tensor();
  const Tensor &src = *reinterpret_cast<const Tensor *>(input);
  Tensor *const cast_dst = reinterpret_cast<Tensor *>(cast_output);
  Tensor *const transpose_dst = reinterpret_cast<Tensor *>(transposed_output);
  cast_transpose(src, absent_noop_flag, cast_dst, transpose_dst, stream);
}

359
360
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop,
                                   NVTETensor cast_output, NVTETensor transposed_output,
                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose_with_noop);
  using namespace transformer_engine;
  // C API entry point: unwrap the opaque handles and forward to cast_transpose,
  // which validates the one-element FP32 no-op flag.
  const Tensor &src = *reinterpret_cast<const Tensor *>(input);
  const Tensor &noop_flag = *reinterpret_cast<const Tensor *>(noop);
  Tensor *const cast_dst = reinterpret_cast<Tensor *>(cast_output);
  Tensor *const transpose_dst = reinterpret_cast<Tensor *>(transposed_output);
  cast_transpose(src, noop_flag, cast_dst, transpose_dst, stream);
}