/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include <cuda_runtime.h>
8
#include <transformer_engine/cast_transpose_noop.h>
Przemek Tredak's avatar
Przemek Tredak committed
9
#include <transformer_engine/transpose.h>
10
11
12

#include <algorithm>

Przemek Tredak's avatar
Przemek Tredak committed
13
#include "../common.h"
Tim Moon's avatar
Tim Moon committed
14
#include "../util/rtc.h"
15
16
#include "../util/string.h"
#include "../utils.cuh"
17
#include "./transpose.h"
Przemek Tredak's avatar
Przemek Tredak committed
18
19

namespace transformer_engine {
20
namespace detail {
Przemek Tredak's avatar
Przemek Tredak committed
21

Tim Moon's avatar
Tim Moon committed
22
namespace {
Przemek Tredak's avatar
Przemek Tredak committed
23

Tim Moon's avatar
Tim Moon committed
24
25
// String with RTC kernel implementation
#include "string_code_transpose_rtc_transpose_cu.h"
Przemek Tredak's avatar
Przemek Tredak committed
26

Tim Moon's avatar
Tim Moon committed
27
28
29
// Hard-coded kernel parameters
// Number of warps in each CUDA block; every block works on one tile.
constexpr size_t warps_per_tile = 4;
// Threads per CUDA block. Used as the launch bound of the general
// kernel and as the block size for both kernel launches in this file.
constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile;
Przemek Tredak's avatar
Przemek Tredak committed
30

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/* Performance heuristics for optimized kernel parameters
 *
 * Describes one candidate (load_size, store_size) pair for the tuned
 * transpose kernel. The constructor decides whether the candidate
 * tiles the tensor exactly and, if so, fills in the quantities used
 * by the cost model. Configs are ordered with operator< so that the
 * cheapest valid config compares smallest.
 */
struct KernelConfig {
  /** Vector load size */
  size_t load_size;
  /** Vector store size */
  size_t store_size;

  /* Whether config is valid */
  bool valid = false;
  /* Number of CUDA blocks */
  size_t num_blocks = 0;

  /* Number of active SMs */
  size_t active_sm_count = 0;
  /* Elements per L1 cache load */
  size_t elements_per_load = 0;
  /* Elements per L1 cache store */
  size_t elements_per_store = 0;

  /* Build a config and evaluate the performance model
   *
   * row_length, num_rows: tensor dimensions in elements
   * type_size:            size of one element in bytes
   * load_size_, store_size_: vector access widths in bytes
   * sm_count:             number of SMs on the device
   *
   * The config is left invalid (valid == false, all model fields
   * zero) if any size is zero, if the access widths or the cache
   * line are not multiples of the element size, or if the tensor
   * dimensions are not multiples of the tile dimensions.
   */
  KernelConfig(size_t row_length, size_t num_rows, size_t type_size, size_t load_size_,
               size_t store_size_, size_t sm_count)
      : load_size{load_size_}, store_size{store_size_} {
    // Reject degenerate sizes up front. They would otherwise be used
    // as divisors below (directly or through the tile dimensions).
    if (type_size == 0 || load_size == 0 || store_size == 0) {
      return;
    }

    // Check that tiles are correctly aligned
    constexpr size_t cache_line_size = 128;
    if (load_size % type_size != 0 || store_size % type_size != 0 ||
        cache_line_size % type_size != 0) {
      return;
    }
    const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size;
    const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size;
    valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
    if (!valid) {
      return;
    }

    // Number of CUDA blocks
    num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements);

    // Parameters for performance model
    constexpr size_t warps_per_sm = 16;  // Rough estimate for saturated SMs
    active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), sm_count);
    elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) / type_size);
    elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) / type_size);
  }

  /* Compare by estimated cost
   *
   * A valid config always compares less than an invalid one, and two
   * invalid configs are equivalent.
   */
  bool operator<(const KernelConfig &other) const {
    if (this->valid && other.valid) {
      // cost ~ (1/elements_per_load + 1/elements_per_store) / active_sms
      // Note: Integer arithmetic ensures stable ordering
      const auto &l1 = this->elements_per_load;
      const auto &s1 = this->elements_per_store;
      const auto &p1 = this->active_sm_count;
      const auto &l2 = other.elements_per_load;
      const auto &s2 = other.elements_per_store;
      const auto &p2 = other.active_sm_count;
      // Both costs are scaled by the common factor l1*s1*p1*l2*s2*p2,
      // which gives
      //   cost1 = (l1 + s1) * l2 * s2 * p2
      //   cost2 = (l2 + s2) * l1 * s1 * p1
      // Writing the products out directly avoids every division, so a
      // config with zero active SMs (e.g. zero blocks) cannot trigger
      // a division by zero. Such a config compares as most expensive.
      const auto cost1 = (l1 + s1) * l2 * s2 * p2;
      const auto cost2 = (l2 + s2) * l1 * s1 * p1;
      return cost1 < cost2;
    } else {
      return this->valid && !other.valid;
    }
  }
};
Przemek Tredak's avatar
Przemek Tredak committed
96

Tim Moon's avatar
Tim Moon committed
97
/* General-purpose transpose kernel
 *
 * Reads a num_rows x row_length matrix from `input` and writes its
 * transpose (row_length x num_rows) to `output`. All global loads and
 * stores are bounds-checked, so the dimensions do not need to be
 * multiples of the tile size.
 *
 * Launch layout: 1D grid and 1D block (only blockIdx.x and
 * threadIdx.x are read). The block is bounded by block_size threads
 * and is treated as THREADS_PER_WARP x warps_per_tile. Each block
 * handles one tile of (THREADS_PER_WARP * store_size / sizeof(Type))
 * rows by (THREADS_PER_WARP * load_size / sizeof(Type)) columns.
 *
 * If `noop` is non-null and noop[0] == 1.0f the kernel exits without
 * touching `output`.
 */
template <size_t load_size, size_t store_size, typename Type>
__global__ void __launch_bounds__(block_size)
    transpose_general_kernel(const Type *__restrict__ const input, const fp32 *const noop,
                             Type *__restrict__ const output, const size_t row_length,
                             const size_t num_rows) {
  // Optional early exit requested through the no-op flag
  const bool skip = (noop != nullptr) && (noop[0] == 1.0f);
  if (skip) {
    return;
  }

  // Elements per vectorized load / store
  constexpr size_t in_width = load_size / sizeof(Type);
  constexpr size_t out_width = store_size / sizeof(Type);
  using InVec = Vec<Type, in_width>;
  using OutVec = Vec<Type, out_width>;

  // Position of this thread inside the block, viewed as a
  // (lanes x warps) grid
  constexpr size_t lanes = THREADS_PER_WARP;
  constexpr size_t warps = warps_per_tile;
  const size_t lane_id = threadIdx.x % lanes;
  const size_t warp_id = threadIdx.x / lanes;

  // Tile extents in elements. A tile is a lanes x lanes grid of
  // (out_width x in_width) subtiles.
  constexpr size_t tile_rows = THREADS_PER_WARP * out_width;
  constexpr size_t tile_cols = THREADS_PER_WARP * in_width;

  // Origin of this block's tile. Consecutive block ids walk down the
  // row dimension first.
  const size_t tiles_along_rows = (num_rows + tile_rows - 1) / tile_rows;
  const size_t tile_row0 = (blockIdx.x % tiles_along_rows) * tile_rows;
  const size_t tile_col0 = (blockIdx.x / tiles_along_rows) * tile_cols;

  // Subtiles handled by each thread
  constexpr size_t passes = THREADS_PER_WARP / warps_per_tile;

  // Stage 1: load subtiles from global memory and transpose them into
  // per-thread storage. Out-of-range elements keep whatever value
  // InVec::clear() leaves behind.
  OutVec regs[in_width][passes];
  const size_t load_col0 = tile_col0 + lane_id * in_width;
#pragma unroll
  for (size_t pass = 0; pass < passes; ++pass) {
    const size_t sub_row = warp_id + pass * warps;
#pragma unroll
    for (size_t r = 0; r < out_width; ++r) {
      const size_t src_row = tile_row0 + sub_row * out_width + r;
      InVec loaded;
      loaded.clear();
      if (src_row < num_rows) {
#pragma unroll
        for (size_t c = 0; c < in_width; ++c) {
          const size_t src_col = load_col0 + c;
          if (src_col < row_length) {
            loaded.data.elt[c] = input[src_row * row_length + src_col];
          }
        }
      }
#pragma unroll
      for (size_t c = 0; c < in_width; ++c) {
        regs[c][pass].data.elt[r] = loaded.data.elt[c];
      }
    }
  }

  // Stage 2: exchange subtiles between threads through shared memory
  // and write the transposed data to global memory, one input column
  // offset at a time.
  __shared__ OutVec staging[THREADS_PER_WARP][THREADS_PER_WARP + 1];
  const size_t store_row0 = tile_row0 + lane_id * out_width;
#pragma unroll
  for (size_t c = 0; c < in_width; ++c) {
#pragma unroll
    for (size_t pass = 0; pass < passes; ++pass) {
      staging[lane_id][warp_id + pass * warps] = regs[c][pass];
    }
    // All writes to the staging buffer must land before any thread reads it
    __syncthreads();
#pragma unroll
    for (size_t pass = 0; pass < passes; ++pass) {
      const size_t sub_col = warp_id + pass * warps;
      const size_t dst_col = tile_col0 + sub_col * in_width + c;
      if (dst_col < row_length) {
#pragma unroll
        for (size_t r = 0; r < out_width; ++r) {
          if (store_row0 + r < num_rows) {
            output[dst_col * num_rows + store_row0 + r] = staging[sub_col][lane_id].data.elt[r];
          }
        }
      }
    }
    // The staging buffer is overwritten in the next iteration
    __syncthreads();
  }
}

194
195
}  // namespace

196
/* Transpose a 2D tensor
 *
 * input:   2D tensor of shape [num_rows, row_length]
 * noop:    optional flag tensor. If its data pointer is non-null it
 *          must hold exactly one FP32 element; it is forwarded to the
 *          kernel (the general kernel skips all work when the value
 *          is 1.0f).
 * output_: 2D tensor of shape [row_length, num_rows] with the same
 *          dtype as the input
 * stream:  CUDA stream on which the kernel is launched
 *
 * When both dimensions are multiples of THREADS_PER_WARP and runtime
 * compilation is enabled, a tuned NVRTC kernel is compiled on first
 * use and launched. Otherwise the statically-compiled general kernel
 * is used. Tensors with a zero dimension are validated and then
 * treated as a no-op.
 */
void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) {
  NVTE_CHECK(output_ != nullptr, "Output tensor is null.");
  Tensor &output = *output_;
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output.data.shape.size() == 2, "Output must have 2 dimensions.");
  const size_t row_length = input.data.shape[1];
  const size_t num_rows = input.data.shape[0];

  NVTE_CHECK(output.data.shape[0] == row_length, "Wrong dimension of output.");
  NVTE_CHECK(output.data.shape[1] == num_rows, "Wrong dimension of output.");

  NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
  NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
  NVTE_CHECK(input.data.dtype == output.data.dtype, "Input (dtype=", to_string(input.data.dtype),
             ") and output (dtype=", to_string(output.data.dtype), ") do not match.");

  if (noop.data.dptr != nullptr) {
    NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), ".");
    NVTE_CHECK(noop.data.dtype == DType::kFloat32);
  }

  // Nothing to transpose for empty tensors. Returning here avoids a
  // zero-block kernel launch (an invalid launch configuration) and a
  // kernel config heuristic evaluated with zero blocks.
  if (row_length == 0 || num_rows == 0) {
    return;
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
      input.data.dtype, Type, constexpr const char *type_name = TypeInfo<Type>::name;
      constexpr size_t type_size = sizeof(Type);

      // Choose between runtime-compiled or statically-compiled kernel
      const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
      if (aligned && rtc::is_enabled()) {  // Runtime-compiled tuned kernel
        // Pick kernel config
        // Note: std::min_element returns the first minimum, so the
        // order of candidates below decides ties.
        std::vector<KernelConfig> kernel_configs;
        kernel_configs.reserve(16);
        const size_t sm_count = static_cast<size_t>(cuda::sm_count());
        auto add_config = [&](size_t load_size, size_t store_size) {
          kernel_configs.emplace_back(row_length, num_rows, type_size, load_size, store_size,
                                      sm_count);
        };
        add_config(8, 8);
        add_config(4, 8);
        add_config(8, 4);
        add_config(4, 4);
        add_config(2, 8);
        add_config(8, 2);
        add_config(2, 4);
        add_config(4, 2);
        add_config(2, 2);
        add_config(1, 8);
        add_config(8, 1);
        add_config(1, 4);
        add_config(4, 1);
        add_config(1, 2);
        add_config(2, 1);
        add_config(1, 1);
        const auto &kernel_config = *std::min_element(kernel_configs.begin(), kernel_configs.end());
        NVTE_CHECK(kernel_config.valid, "invalid kernel config");
        const size_t load_size = kernel_config.load_size;
        const size_t store_size = kernel_config.store_size;
        const size_t num_blocks = kernel_config.num_blocks;

        // Compile NVRTC kernel if needed and launch
        auto &rtc_manager = rtc::KernelManager::instance();
        const std::string kernel_label = concat_strings(
            "transpose"
            ",type=",
            type_name, ",load_size=", load_size, ",store_size=", store_size);
        if (!rtc_manager.is_compiled(kernel_label)) {
          std::string code = string_code_transpose_rtc_transpose_cu;
          code = regex_replace(code, "__TYPE__", type_name);
          code = regex_replace(code, "__LOAD_SIZE__", load_size);
          code = regex_replace(code, "__STORE_SIZE__", store_size);
          code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
          code = regex_replace(code, "__BLOCK_SIZE__", block_size);
          rtc_manager.compile(kernel_label, "transpose_optimized_kernel", code,
                              "transformer_engine/common/transpose/rtc/transpose.cu");
        }
        rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
                           static_cast<const Type *>(input.data.dptr),
                           static_cast<const fp32 *>(noop.data.dptr),
                           static_cast<Type *>(output.data.dptr), row_length, num_rows);
      } else {  // Statically-compiled general kernel
        constexpr size_t load_size = 4;
        constexpr size_t store_size = 4;
        constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP;
        constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP;
        const int num_blocks = (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
        transpose_general_kernel<load_size, store_size, Type>
            <<<num_blocks, block_size, 0, stream>>>(static_cast<const Type *>(input.data.dptr),
                                                    static_cast<const fp32 *>(noop.data.dptr),
                                                    static_cast<Type *>(output.data.dptr),
                                                    row_length, num_rows);
        NVTE_CHECK_CUDA(cudaGetLastError());
      });  // NOLINT(*)
}

289
}  // namespace detail
Przemek Tredak's avatar
Przemek Tredak committed
290
291
}  // namespace transformer_engine

292
/* C API entry point: transpose `input` into `output` on `stream`.
 *
 * Forwards to detail::transpose with a default-constructed Tensor in
 * place of the no-op flag.
 */
void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_transpose);
  using namespace transformer_engine;
  Tensor no_flag{};
  const auto &src = *convertNVTETensorCheck(input);
  auto dst = convertNVTETensor(output);
  detail::transpose(src, no_flag, dst, stream);
}

299
/* C API entry point: transpose `input` into `output` on `stream`,
 * passing the caller-provided `noop` flag tensor through to
 * detail::transpose.
 */
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
                              cudaStream_t stream) {
  NVTE_API_CALL(nvte_transpose_with_noop);
  using namespace transformer_engine;
  const auto &src = *convertNVTETensorCheck(input);
  const auto &flag = *convertNVTETensorCheck(noop);
  auto dst = convertNVTETensor(output);
  detail::transpose(src, flag, dst, stream);
}