Unverified Commit 14c1ecd0 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

NVRTC kernels for cast-transpose (#258)



* Add NVRTC kernels for cast-transpose
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Update copyright year
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add noop flag to NVRTC cast-transpose kernel
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Apply suggestions from code review
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent c63766d4
......@@ -81,7 +81,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{65536, 128},
{256, 256},
{120, 2080},
{8, 8}};
{8, 8},
{1, 3221}, // Prime 456
{2333, 1}, // Prime 345
{1481, 677}}; // Primes 234, 123
} // namespace
class CTTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
......
......@@ -77,10 +77,12 @@ endfunction()
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
string_path_cuda_include)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
make_string_header_from_file(transpose/rtc/cast_transpose.cu
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers")
......
......@@ -6,432 +6,360 @@
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <algorithm>
#include <cuda_runtime.h>
#include <iostream>
#include <cfloat>
#include "../utils.cuh"
#include "../common.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
namespace transformer_engine {
template <bool full_tile, int nvec_in, int nvec_out, typename IVec, typename OVec, typename CType>
inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in],
typename OVec::type *output_cast_tile,
const size_t current_place,
const size_t stride,
CType &max, // NOLINT(*)
const CType scale,
const bool valid_store) {
using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = static_cast<CType>(in[i].data.elt[j]);
const T elt_o = T(scale * tmp);
out_cast.data.elt[j] = elt_o;
out_trans[j].data.elt[i] = elt_o; // thread tile transpose
__builtin_assume(max >= 0);
max = fmaxf(fabsf(tmp), max);
}
if (full_tile || valid_store) {
out_cast.store_to(output_cast_tile, current_place + stride * i);
namespace {
// String with RTC kernel implementation
#include "string_code_transpose_rtc_cast_transpose_cu.h"
// Hard-coded kernel parameters
using CType = float;
constexpr size_t warps_per_tile = 4;
constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile;
/* Performance heuristics for optimized kernel parameters */
struct KernelConfig {
/** Vector load size */
size_t load_size = 0;
/** Vector store size to transposed output */
size_t store_size = 0;
/* Whether config is valid */
bool valid = false;
/* Number of CUDA blocks */
size_t num_blocks = 0;
/* Number of active SMs */
size_t active_sm_count = 0;
/* Elements per L1 cache load */
size_t elements_per_load = 0;
/* Elements per L1 cache store to cast output*/
size_t elements_per_store_c = 0;
/* Elements per L1 cache store to transposed output */
size_t elements_per_store_t = 0;
KernelConfig(size_t row_length,
size_t num_rows,
size_t itype_size,
size_t otype_size,
size_t load_size_,
size_t store_size_)
: load_size{load_size_}
, store_size{store_size_} {
// Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128;
if (load_size % itype_size != 0
|| store_size % otype_size != 0
|| cache_line_size % itype_size != 0
|| cache_line_size % otype_size != 0) {
return;
}
const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size;
const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size;
valid = (row_length % row_tile_elements == 0
&& num_rows % col_tile_elements == 0);
if (!valid) {
return;
}
}
// STUFF TO TUNE
constexpr unsigned int n_warps_per_tile = 4;
constexpr unsigned int max_threads_per_block = 256;
static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel(const IType * const input,
const CType * const noop,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
if (noop != nullptr && noop[0] == 1.0f) return;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
cast_and_transpose_regs<true>(in[current_in ^ 1], out_trans, my_output_c_tile,
current_place, stride, max, scale, true);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
// Number of CUDA blocks
num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements);
// Parameters for performance model
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size)
/ itype_size);
elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size)
/ otype_size);
elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size)
/ otype_size);
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
/* Compare by estimated cost */
bool operator<(const KernelConfig &other) const {
if (this->valid && other.valid) {
// cost ~ (1/elements_per_load
// + 1/elements_per_store_c
// + 1/elements_per_store_t) / active_sms
// Note: Integer arithmetic ensures stable ordering
const auto &l1 = this->elements_per_load;
const auto &sc1 = this->elements_per_store_c;
const auto &st1 = this->elements_per_store_t;
const auto &p1 = this->active_sm_count;
const auto &l2 = other.elements_per_load;
const auto &sc2 = other.elements_per_store_c;
const auto &st2 = other.elements_per_store_t;
const auto &p2 = other.active_sm_count;
const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2;
const auto cost1 = (scale/l1 + scale/sc1 + scale/st1) / p1;
const auto cost2 = (scale/l2 + scale/sc2 + scale/st2) / p2;
return cost1 < cost2;
} else {
return this->valid && !other.valid;
}
__syncthreads();
}
};
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, max);
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel_notaligned(const IType * const input,
const CType * const noop,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
__launch_bounds__(block_size)
cast_transpose_general_kernel(const IType * __restrict__ const input,
const CType * __restrict__ const noop,
OType * __restrict__ const output_c,
OType * __restrict__ const output_t,
const CType * __restrict__ const scale_ptr,
CType * __restrict__ const amax_ptr,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(IType);
constexpr size_t nvec_out = store_size / sizeof(OType);
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
} else {
in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
} else {
in[current_in][j].clear();
using OVecT = Vec<OType, nvec_out>;
// Thread indices
// Note: Block is interpreted as a warp_size x num_warps grid
constexpr size_t bdimx = THREADS_PER_WARP;
constexpr size_t bdimy = warps_per_tile;
const size_t tid = threadIdx.x;
const size_t tidx = tid % bdimx;
const size_t tidy = tid / bdimx;
const size_t bid = blockIdx.x;
// Input tensors are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out;
constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in;
// Position of tile within tensor
const size_t num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
const size_t tile_id_m = bid % num_tiles_m;
const size_t tile_id_n = bid / num_tiles_m;
const size_t tile_row = tile_id_m * tile_dim_m;
const size_t tile_col = tile_id_n * tile_dim_n;
// Number of nvec_out x nvec_in subtiles for each thread to
// load/store
constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile;
// FP8 factors
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
CType amax = 0;
// Load input and store to registers
// Note: Each thread loads num_iterations subtiles, computes amax,
// casts type, and transposes in registers.
OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
if (row < num_rows) {
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) {
const CType in = input[row * row_length + col + j2];
const OType out = OType(in * scale);
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(in), amax);
output_c[row * row_length + col + j2] = out;
local_output_t[j2][iter].data.elt[i2] = out;
}
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs<false>(in[current_in ^ 1], out_trans, my_output_c_tile,
current_place, stride, max, scale, valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
// Copy transposed output from registers to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1];
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2;
if (col < row_length) {
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) {
output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
}
}
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
// Reduce amax over block
if (amax_ptr != nullptr) {
amax = reduce_max<warps_per_tile>(amax, tidy);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, max);
atomicMaxFloat(amax_ptr, amax);
}
}
}
} // namespace
void cast_transpose(const Tensor &input,
const Tensor &noop,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *cast_output_,
Tensor *transposed_output_,
cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
Tensor &cast_output = *cast_output_;
Tensor &transposed_output = *transposed_output_;
// Check no-op flag
if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1,
"Expected 1 element, ",
"but found ", numel(noop), ".");
size_t numel = 1;
for (const auto& dim : noop.data.shape) {
numel *= dim;
}
NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
// Check tensor dims
CheckInputTensor(input, "cast_transpose_input");
CheckOutputTensor(cast_output, "cast_output");
CheckOutputTensor(transposed_output, "transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions.");
NVTE_CHECK(transposed_output.data.shape.size() == 2,
"Transposed output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
// Launch specific cast-transpose kernel
#define LAUNCH_KERNEL(kernel, nvec_in, nvec_out, n_tiles, n_blocks, InputType, OutputType) \
do { \
cudaFuncSetAttribute(kernel<nvec_in, nvec_out, fp32, InputType, OutputType>, \
cudaFuncAttributePreferredSharedMemoryCarveout, \
100); \
kernel<nvec_in, nvec_out, fp32, InputType, OutputType> \
<<<n_blocks, \
cast_transpose_num_threads, \
cast_transpose_num_threads / n_warps_per_tile * \
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>), \
stream>>>( \
reinterpret_cast<const InputType *>(input.data.dptr), \
reinterpret_cast<const fp32 *>(noop.data.dptr), \
reinterpret_cast<OutputType *>(cast_output->data.dptr), \
reinterpret_cast<OutputType *>(transposed_output->data.dptr), \
reinterpret_cast<const fp32 *>(cast_output->scale.dptr), \
reinterpret_cast<fp32 *>(cast_output->amax.dptr), \
row_length, num_rows, n_tiles); \
} while (false)
// Launch cast-transpose kernel for given vector sizes
#define LAUNCH_KERNEL_VEC_SIZES(load_size, store_size, InputType, OutputType) \
do { \
constexpr int nvec_in = load_size / sizeof(InputType); \
constexpr int nvec_out = store_size / sizeof(OutputType); \
\
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); \
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); \
\
const size_t n_tiles = get_n_tiles(load_size, store_size); \
const size_t n_blocks = get_n_blocks(n_tiles); \
\
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && \
num_rows % (nvec_out * THREADS_PER_WARP) == 0; \
\
if (full_tile) { \
LAUNCH_KERNEL(cast_transpose_kernel, \
nvec_in, nvec_out, n_tiles, n_blocks, \
InputType, OutputType); \
} else { \
LAUNCH_KERNEL(cast_transpose_kernel_notaligned, \
nvec_in, nvec_out, n_tiles, n_blocks, \
InputType, OutputType); \
} \
} while (false)
NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output.");
NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output.");
NVTE_CHECK(transposed_output.data.shape[0] == row_length,
"Wrong dimension of transposed output.");
NVTE_CHECK(transposed_output.data.shape[1] == num_rows,
"Wrong dimension of transposed output.");
// Check tensor pointers
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated.");
NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated.");
NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype,
"Cast and transposed output types must match.");
NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr,
"Cast and transposed outputs need to share amax tensor.");
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
// Estimate number of SMs
// Note: H100 has 132 SMs, A100 has 108 SMs.
// Note: Directly querying number of SMs with cudaGetDeviceProperties is
// slow (>1 ms). Consider querying once and caching.
const int n_sms = 128;
// Helper functions to get kernel configuration
auto get_n_tiles = [=] (size_t load_size, size_t store_size) -> int {
constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP);
size_t nvec_in = load_size / sizeof(InputType);
size_t nvec_out = store_size / sizeof(OutputType);
size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) *
DIVUP(num_rows, nvec_out * threads_per_warp);
return n_tiles;
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output.data.dtype, OutputType,
constexpr const char *itype_name = TypeInfo<InputType>::name;
constexpr const char *otype_name = TypeInfo<OutputType>::name;
constexpr size_t itype_size = sizeof(InputType);
constexpr size_t otype_size = sizeof(OutputType);
// Choose between runtime-compiled or statically-compiled kernel
const bool aligned = (row_length % THREADS_PER_WARP == 0
&& num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size, size_t store_size) {
kernel_configs.emplace_back(row_length, num_rows,
itype_size, otype_size,
load_size, store_size);
};
auto get_n_blocks = [=] (size_t n_tiles) -> int {
size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
return n_blocks;
};
// Estimate optimal vector sizes and run
// Note: Consider reducing to 2B or 1B loads/stores for
// sufficiently small matrices. Need to consider whether reduced
// cache efficiency is worth increased SM utilization. Also need
// to keep in mind whether datatype can fit.
const size_t estimated_n_tiles = get_n_tiles(8, 8);
const size_t estimated_n_blocks = get_n_blocks(estimated_n_tiles);
if (estimated_n_blocks >= n_sms) {
LAUNCH_KERNEL_VEC_SIZES(8, 8, InputType, OutputType);
} else {
LAUNCH_KERNEL_VEC_SIZES(4, 4, InputType, OutputType);
add_config(8, 8);
add_config(4, 8); add_config(8, 4);
add_config(4, 4);
add_config(2, 8); add_config(8, 2);
add_config(2, 4); add_config(4, 2);
add_config(2, 2);
add_config(1, 8); add_config(8, 1);
add_config(1, 4); add_config(4, 1);
add_config(1, 2); add_config(2, 1);
add_config(1, 1);
const auto &kernel_config = *std::min_element(kernel_configs.begin(),
kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size;
const size_t store_size = kernel_config.store_size;
const size_t num_blocks = kernel_config.num_blocks;
// Compile NVRTC kernel if needed and launch
auto& rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings("cast_transpose"
",itype=", itype_name,
",otype=", otype_name,
",load_size=", load_size,
",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_cast_transpose_cu;
code = regex_replace(code, "__ITYPE__", itype_name);
code = regex_replace(code, "__OTYPE__", otype_name);
code = regex_replace(code, "__LOAD_SIZE__", load_size);
code = regex_replace(code, "__STORE_SIZE__", store_size);
code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
code = regex_replace(code, "__BLOCK_SIZE__", block_size);
rtc_manager.compile(kernel_label,
"cast_transpose_optimized_kernel",
code,
"transformer_engine/common/transpose/rtc/cast_transpose.cu");
}
rtc_manager.launch(kernel_label,
num_blocks, block_size, 0, stream,
static_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const CType *>(noop.data.dptr),
static_cast<OutputType*>(cast_output.data.dptr),
static_cast<OutputType*>(transposed_output.data.dptr),
static_cast<const CType*>(cast_output.scale.dptr),
static_cast<CType*>(cast_output.amax.dptr),
row_length, num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP;
constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
const int num_blocks = (DIVUP(row_length, row_tile_size)
* DIVUP(num_rows, col_tile_size));
cast_transpose_general_kernel<load_size, store_size, InputType, OutputType>
<<<num_blocks, block_size, 0, stream>>>(
static_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const CType *>(noop.data.dptr),
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr),
row_length, num_rows);
}
); // NOLINT(*)
); // NOLINT(*)
#undef LAUNCH_KERNEL
#undef LAUNCH_KERNEL_VEC_SIZES
}
} // namespace transformer_engine
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "utils.cuh"
using namespace transformer_engine;
namespace {
// Parameters
using CType = float;
using IType = __ITYPE__;
using OType = __OTYPE__;
constexpr size_t load_size = __LOAD_SIZE__;
constexpr size_t store_size = __STORE_SIZE__;
constexpr size_t warps_per_tile = __WARPS_PER_TILE__;
constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace
__global__ void
__launch_bounds__(block_size)
cast_transpose_optimized_kernel(const IType * __restrict__ const input,
const CType * __restrict__ const noop,
OType * __restrict__ const output_c,
OType * __restrict__ const output_t,
const CType * __restrict__ const scale_ptr,
CType * __restrict__ const amax_ptr,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(IType);
constexpr size_t nvec_out = store_size / sizeof(OType);
using IVec = Vec<IType, nvec_in>;
using OVecC = Vec<OType, nvec_in>;
using OVecT = Vec<OType, nvec_out>;
// Thread indices
// Note: Block is interpreted as a warp_size x num_warps grid
constexpr size_t bdimx = THREADS_PER_WARP;
constexpr size_t bdimy = warps_per_tile;
const size_t tid = threadIdx.x;
const size_t tidx = tid % bdimx;
const size_t tidy = tid / bdimx;
const size_t bid = blockIdx.x;
// Input tensors are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out;
constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in;
// Position of tile within tensor
const size_t num_tiles_m = num_rows / tile_dim_m;
const size_t tile_id_m = bid % num_tiles_m;
const size_t tile_id_n = bid / num_tiles_m;
const size_t tile_row = tile_id_m * tile_dim_m;
const size_t tile_col = tile_id_n * tile_dim_n;
// Number of nvec_out x nvec_in subtiles for each thread to
// load/store
constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile;
// FP8 factors
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
CType amax = 0;
// Load input to registers and transpose
// Note: Each thread loads num_iterations subtiles, computes amax,
// casts type, and transposes in registers.
OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
IVec local_input;
OVecC local_output_c;
local_input.load_from(&input[row * row_length + col]);
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
const CType in = static_cast<CType>(local_input.data.elt[j2]);
const OType out = OType(in * scale);
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(in), amax);
local_output_c.data.elt[j2] = out;
local_output_t[j2][iter].data.elt[i2] = out;
}
local_output_c.store_to(&output_c[row * row_length + col]);
}
}
// Copy from registers to shared memory to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1];
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter];
}
__syncthreads();
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2;
shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
}
__syncthreads();
}
// Reduce amax over block
if (amax_ptr != nullptr) {
amax = reduce_max<warps_per_tile>(amax, tidy);
if (threadIdx.x == 0) {
atomicMaxFloat(amax_ptr, amax);
}
}
}
......@@ -6,13 +6,15 @@
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <algorithm>
#include <cuda_runtime.h>
#include <iostream>
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "../util/string.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
namespace transformer_engine {
......@@ -25,7 +27,80 @@ namespace {
constexpr size_t warps_per_tile = 4;
constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile;
} // namespace
/* Performance heuristics for optimized kernel parameters */
struct KernelConfig {
/** Vector load size */
size_t load_size;
/** Vector store size */
size_t store_size;
/* Whether config is valid */
bool valid = false;
/* Number of CUDA blocks */
size_t num_blocks = 0;
/* Number of active SMs */
size_t active_sm_count = 0;
/* Elements per L1 cache load */
size_t elements_per_load = 0;
/* Elements per L1 cache store */
size_t elements_per_store = 0;
KernelConfig(size_t row_length,
size_t num_rows,
size_t type_size,
size_t load_size_,
size_t store_size_)
: load_size{load_size_}
, store_size{store_size_} {
// Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128;
if (load_size % type_size != 0
|| store_size % type_size != 0
|| cache_line_size % type_size != 0) {
return;
}
const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size;
const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size;
valid = (row_length % row_tile_elements == 0
&& num_rows % col_tile_elements == 0);
if (!valid) {
return;
}
// Number of CUDA blocks
num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements);
// Parameters for performance model
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size)
/ type_size);
elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size)
/ type_size);
}
/* Compare by estimated cost */
bool operator<(const KernelConfig &other) const {
if (this->valid && other.valid) {
// cost ~ (1/elements_per_load + 1/elements_per_store) / active_sms
// Note: Integer arithmetic ensures stable ordering
const auto &l1 = this->elements_per_load;
const auto &s1 = this->elements_per_store;
const auto &p1 = this->active_sm_count;
const auto &l2 = other.elements_per_load;
const auto &s2 = other.elements_per_store;
const auto &p2 = other.active_sm_count;
const auto scale = l1 * s1 * p1 * l2 * s2 * p2;
const auto cost1 = (scale/l1 + scale/s1) / p1;
const auto cost2 = (scale/l2 + scale/s2) / p2;
return cost1 < cost2;
} else {
return this->valid && !other.valid;
}
}
};
template <size_t load_size, size_t store_size, typename Type>
__global__ void
......@@ -127,6 +202,8 @@ transpose_general_kernel(const Type * __restrict__ const input,
}
}
} // namespace
void transpose(const Tensor &input,
const Tensor &noop,
Tensor *output_,
......@@ -170,82 +247,36 @@ void transpose(const Tensor &input,
const bool aligned = (row_length % THREADS_PER_WARP == 0
&& num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// Determine kernel config
size_t load_size = 8;
size_t store_size = 8;
auto is_tile_aligned = [&](size_t load_size_, size_t store_size_) -> bool {
return (row_length % (load_size / type_size * THREADS_PER_WARP) == 0
&& num_rows % (store_size / type_size * THREADS_PER_WARP) == 0);
};
auto num_blocks = [&](size_t load_size_, size_t store_size_) -> int {
const size_t row_tile_size = load_size_ / type_size * THREADS_PER_WARP;
const size_t col_tile_size = store_size_ / type_size * THREADS_PER_WARP;
return (row_length / row_tile_size) * (num_rows / col_tile_size);
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size, size_t store_size) {
kernel_configs.emplace_back(row_length, num_rows, type_size,
load_size, store_size);
};
do {
const int sm_count = cuda::sm_count();
// Try maximizing SM occupancy without sacrificing cache
// efficiency
// Note: 32 threads/warp access 128B L1 cache line, so 4B
// loads/stores achieve full cache efficiency
if constexpr (type_size > 4) break;
if (is_tile_aligned(load_size, store_size)
&& num_blocks(load_size, store_size) >= 4*sm_count) {
break;
}
load_size = 4; store_size = 8;
if (is_tile_aligned(load_size, store_size)
&& num_blocks(load_size, store_size) >= 4*sm_count) {
break;
}
load_size = 4; store_size = 4;
if (is_tile_aligned(load_size, store_size)
&& num_blocks(load_size, store_size) >= sm_count) {
break;
}
// Simple performance model to balance SM occupancy and cache
// efficiency
auto cost = [&](int load_size_, int store_size_) -> double {
int active_sms = std::min(sm_count, num_blocks(load_size_, store_size_));
// Amortize memory accesses over 128B L1 cache line
int elements_per_load = std::min(128, load_size_) / type_size;
int elements_per_store = std::min(128, store_size_) / type_size;
return (1.0 / elements_per_load + 1.0 / elements_per_store) / active_sms;
};
if constexpr (type_size > 2) break;
if (is_tile_aligned(load_size, store_size)
&& cost(2, 4) >= cost(load_size, store_size)) {
break;
}
load_size = 2; store_size = 4;
if (is_tile_aligned(load_size, store_size)
&& cost(2, 2) >= cost(load_size, store_size)) {
break;
}
load_size = 2; store_size = 2;
if constexpr (type_size > 1) break;
if (is_tile_aligned(load_size, store_size)
&& cost(1, 2) >= cost(load_size, store_size)) {
break;
}
load_size = 1; store_size = 2;
if (is_tile_aligned(load_size, store_size)
&& cost(1, 1) >= cost(load_size, store_size)) {
break;
}
load_size = 1; store_size = 1;
} while (false);
NVTE_CHECK(is_tile_aligned(load_size, store_size),
"memory accesses are not properly aligned");
add_config(8, 8);
add_config(4, 8); add_config(8, 4);
add_config(4, 4);
add_config(2, 8); add_config(8, 2);
add_config(2, 4); add_config(4, 2);
add_config(2, 2);
add_config(1, 8); add_config(8, 1);
add_config(1, 4); add_config(4, 1);
add_config(1, 2); add_config(2, 1);
add_config(1, 1);
const auto &kernel_config = *std::min_element(kernel_configs.begin(),
kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size;
const size_t store_size = kernel_config.store_size;
const size_t num_blocks = kernel_config.num_blocks;
// Compile NVRTC kernel if needed and launch
auto& rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings("transpose"
",type=", type_name,
",load_size=", load_size,
",store_size", store_size);
",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_transpose_cu;
code = regex_replace(code, "__TYPE__", type_name);
......@@ -259,7 +290,7 @@ void transpose(const Tensor &input,
"transformer_engine/common/transpose/rtc/transpose.cu");
}
rtc_manager.launch(kernel_label,
num_blocks(load_size, store_size), block_size, 0, stream,
num_blocks, block_size, 0, stream,
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type*>(output.data.dptr),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment