Unverified commit 14c1ecd0 authored by Tim Moon, committed by GitHub

NVRTC kernels for cast-transpose (#258)



* Add NVRTC kernels for cast-transpose
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update copyright year
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add noop flag to NVRTC cast-transpose kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent c63766d4
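For readers unfamiliar with the mechanism: NVRTC compiles CUDA C++ source strings to PTX at runtime, which lets these kernels be specialized per data type and vector width instead of being instantiated ahead of time. A minimal sketch of that round trip, using only public NVRTC and CUDA driver API calls (error handling, caching, and TE's actual rtc::KernelManager wrapper are omitted; compile_kernel is a hypothetical helper, and a current CUDA context is assumed):

#include <nvrtc.h>
#include <cuda.h>
#include <vector>

CUfunction compile_kernel(const char *code, const char *kernel_name) {
  // Compile the source string to PTX
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, code, "kernel.cu", 0, nullptr, nullptr);
  nvrtcAddNameExpression(prog, kernel_name);  // track the mangled name
  const char *opts[] = {"--std=c++17"};
  nvrtcCompileProgram(prog, 1, opts);
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  // Look up the lowered (mangled) name, then load the PTX with the driver API
  const char *lowered_name;
  nvrtcGetLoweredName(prog, kernel_name, &lowered_name);
  CUmodule module;
  CUfunction kernel;
  cuModuleLoadDataEx(&module, ptx.data(), 0, nullptr, nullptr);
  cuModuleGetFunction(&kernel, module, lowered_name);
  nvrtcDestroyProgram(&prog);
  return kernel;
}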
@@ -81,7 +81,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{65536, 128},
{256, 256},
{120, 2080},
{8, 8}};
{8, 8},
{1, 3221}, // 3221 is the 456th prime
{2333, 1}, // 2333 is the 345th prime
{1481, 677}}; // 1481 and 677 are the 234th and 123rd primes
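// Note (added for clarity): prime dimensions are never multiples of
// THREADS_PER_WARP (32), so these shapes exercise the general
// non-aligned kernel rather than the tuned runtime-compiled path.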
} // namespace
class CTTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
......
@@ -77,10 +77,12 @@ endfunction()
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
string_path_cuda_include)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
make_string_header_from_file(transpose/rtc/cast_transpose.cu
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers")
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "utils.cuh"
using namespace transformer_engine;
namespace {
// Parameters
using CType = float;
using IType = __ITYPE__;
using OType = __OTYPE__;
constexpr size_t load_size = __LOAD_SIZE__;
constexpr size_t store_size = __STORE_SIZE__;
constexpr size_t warps_per_tile = __WARPS_PER_TILE__;
constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace
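// Illustration (not part of the generated file): the host code replaces
// the placeholders above before NVRTC compilation, so a bf16 -> fp8
// kernel with 8-byte vector accesses might be specialized as (type
// spellings are assumptions):
//   using IType = __nv_bfloat16;
//   using OType = __nv_fp8_e4m3;
//   constexpr size_t load_size = 8, store_size = 8;
//   constexpr size_t warps_per_tile = 4, block_size = 128;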
__global__ void
__launch_bounds__(block_size)
cast_transpose_optimized_kernel(const IType * __restrict__ const input,
const CType * __restrict__ const noop,
OType * __restrict__ const output_c,
OType * __restrict__ const output_t,
const CType * __restrict__ const scale_ptr,
CType * __restrict__ const amax_ptr,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(IType);
constexpr size_t nvec_out = store_size / sizeof(OType);
using IVec = Vec<IType, nvec_in>;
using OVecC = Vec<OType, nvec_in>;
using OVecT = Vec<OType, nvec_out>;
// Thread indices
// Note: Block is interpreted as a warp_size x num_warps grid
constexpr size_t bdimx = THREADS_PER_WARP;
constexpr size_t bdimy = warps_per_tile;
const size_t tid = threadIdx.x;
const size_t tidx = tid % bdimx;
const size_t tidy = tid / bdimx;
const size_t bid = blockIdx.x;
// Input tensors are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out;
constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in;
// Position of tile within tensor
const size_t num_tiles_m = num_rows / tile_dim_m;
const size_t tile_id_m = bid % num_tiles_m;
const size_t tile_id_n = bid / num_tiles_m;
const size_t tile_row = tile_id_m * tile_dim_m;
const size_t tile_col = tile_id_n * tile_dim_n;
// Number of nvec_out x nvec_in subtiles for each thread to
// load/store
constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile;
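// Worked example (added for clarity; values illustrative): with
// IType = float (4 B), OType = an FP8 type (1 B), and
// load_size = store_size = 8:
//   nvec_in  = 8 / 4 = 2 input elements per vector load
//   nvec_out = 8 / 1 = 8 output elements per vector store
//   tile_dim_m = 32 * 8 = 256 rows, tile_dim_n = 32 * 2 = 64 columns
//   num_iterations = 32 / 4 = 8 subtiles per thread
// so each 4-warp block covers one 256 x 64 tile of the input.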
// FP8 factors
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
CType amax = 0;
// Load input to registers and transpose
// Note: Each thread loads num_iterations subtiles, computes amax,
// casts type, and transposes in registers.
OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
IVec local_input;
OVecC local_output_c;
local_input.load_from(&input[row * row_length + col]);
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
const CType in = static_cast<CType>(local_input.data.elt[j2]);
const OType out = OType(in * scale);
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(in), amax);
local_output_c.data.elt[j2] = out;
local_output_t[j2][iter].data.elt[i2] = out;
}
local_output_c.store_to(&output_c[row * row_length + col]);
}
}
// Copy from registers to shared memory to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1];
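// Note (added for clarity): the +1 padding on the inner dimension
// staggers the tile's columns across shared memory banks, so a warp
// writing down a column and reading across a row avoids bank conflicts.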
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter];
}
__syncthreads();
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2;
shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
}
__syncthreads();
}
// Reduce amax over block
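// Note (added for clarity): reduce_max (from utils.cuh) combines the
// per-thread amax values across the block's warps_per_tile warps; the
// result lands on thread 0, which folds it into the global amax with an
// atomic float max.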
if (amax_ptr != nullptr) {
amax = reduce_max<warps_per_tile>(amax, tidy);
if (threadIdx.x == 0) {
atomicMaxFloat(amax_ptr, amax);
}
}
}
@@ -6,13 +6,15 @@
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <algorithm>
#include <cuda_runtime.h>
#include <iostream>
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "../util/string.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
namespace transformer_engine {
@@ -25,7 +27,80 @@ namespace {
constexpr size_t warps_per_tile = 4;
constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile;
} // namespace
/* Performance heuristics for optimized kernel parameters */
struct KernelConfig {
/* Vector load size */
size_t load_size;
/* Vector store size */
size_t store_size;
/* Whether config is valid */
bool valid = false;
/* Number of CUDA blocks */
size_t num_blocks = 0;
/* Number of active SMs */
size_t active_sm_count = 0;
/* Elements per L1 cache load */
size_t elements_per_load = 0;
/* Elements per L1 cache store */
size_t elements_per_store = 0;
KernelConfig(size_t row_length,
size_t num_rows,
size_t type_size,
size_t load_size_,
size_t store_size_)
: load_size{load_size_}
, store_size{store_size_} {
// Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128;
if (load_size % type_size != 0
|| store_size % type_size != 0
|| cache_line_size % type_size != 0) {
return;
}
const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size;
const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size;
valid = (row_length % row_tile_elements == 0
&& num_rows % col_tile_elements == 0);
if (!valid) {
return;
}
// Number of CUDA blocks
num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements);
// Parameters for performance model
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size)
/ type_size);
elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size)
/ type_size);
}
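// Worked example (added for clarity; values illustrative): for
// row_length = 8192, num_rows = 2048, type_size = 2, and
// load_size = store_size = 8, the tiles are 128 x 128 elements, so
// valid = true, num_blocks = 64 * 16 = 1024,
// active_sm_count = min(256, sm_count), and
// elements_per_load = elements_per_store = min(128 B, 256 B) / 2 B = 64.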
/* Compare by estimated cost */
bool operator<(const KernelConfig &other) const {
if (this->valid && other.valid) {
// cost ~ (1/elements_per_load + 1/elements_per_store) / active_sms
// Note: Integer arithmetic ensures stable ordering
const auto &l1 = this->elements_per_load;
const auto &s1 = this->elements_per_store;
const auto &p1 = this->active_sm_count;
const auto &l2 = other.elements_per_load;
const auto &s2 = other.elements_per_store;
const auto &p2 = other.active_sm_count;
const auto scale = l1 * s1 * p1 * l2 * s2 * p2;
const auto cost1 = (scale/l1 + scale/s1) / p1;
const auto cost2 = (scale/l2 + scale/s2) / p2;
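// Note (added for clarity): multiplying both costs through by
// scale = l1*s1*p1*l2*s2*p2 keeps everything in exact integer
// arithmetic: every division above is exact, the expressions reduce to
// cost1 = l2*s2*p2*(l1 + s1) and cost2 = l1*s1*p1*(l2 + s2), and
// cost1 < cost2 exactly when (1/l1 + 1/s1)/p1 < (1/l2 + 1/s2)/p2.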
return cost1 < cost2;
} else {
return this->valid && !other.valid;
}
}
};
template <size_t load_size, size_t store_size, typename Type>
__global__ void
@@ -127,6 +202,8 @@ transpose_general_kernel(const Type * __restrict__ const input,
}
}
} // namespace
void transpose(const Tensor &input,
const Tensor &noop,
Tensor *output_,
@@ -170,82 +247,36 @@ void transpose(const Tensor &input,
const bool aligned = (row_length % THREADS_PER_WARP == 0
&& num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// Determine kernel config
size_t load_size = 8;
size_t store_size = 8;
auto is_tile_aligned = [&](size_t load_size_, size_t store_size_) -> bool {
return (row_length % (load_size / type_size * THREADS_PER_WARP) == 0
&& num_rows % (store_size / type_size * THREADS_PER_WARP) == 0);
};
auto num_blocks = [&](size_t load_size_, size_t store_size_) -> int {
const size_t row_tile_size = load_size_ / type_size * THREADS_PER_WARP;
const size_t col_tile_size = store_size_ / type_size * THREADS_PER_WARP;
return (row_length / row_tile_size) * (num_rows / col_tile_size);
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size, size_t store_size) {
kernel_configs.emplace_back(row_length, num_rows, type_size,
load_size, store_size);
};
do {
const int sm_count = cuda::sm_count();
// Try maximizing SM occupancy without sacrificing cache
// efficiency
// Note: 32 threads/warp access 128B L1 cache line, so 4B
// loads/stores achieve full cache efficiency
if constexpr (type_size > 4) break;
if (is_tile_aligned(load_size, store_size)
&& num_blocks(load_size, store_size) >= 4*sm_count) {
break;
}
load_size = 4; store_size = 8;
if (is_tile_aligned(load_size, store_size)
&& num_blocks(load_size, store_size) >= 4*sm_count) {
break;
}
load_size = 4; store_size = 4;
if (is_tile_aligned(load_size, store_size)
&& num_blocks(load_size, store_size) >= sm_count) {
break;
}
// Simple performance model to balance SM occupancy and cache
// efficiency
auto cost = [&](int load_size_, int store_size_) -> double {
int active_sms = std::min(sm_count, num_blocks(load_size_, store_size_));
// Amortize memory accesses over 128B L1 cache line
int elements_per_load = std::min(128, load_size_) / type_size;
int elements_per_store = std::min(128, store_size_) / type_size;
return (1.0 / elements_per_load + 1.0 / elements_per_store) / active_sms;
};
if constexpr (type_size > 2) break;
if (is_tile_aligned(load_size, store_size)
&& cost(2, 4) >= cost(load_size, store_size)) {
break;
}
load_size = 2; store_size = 4;
if (is_tile_aligned(load_size, store_size)
&& cost(2, 2) >= cost(load_size, store_size)) {
break;
}
load_size = 2; store_size = 2;
if constexpr (type_size > 1) break;
if (is_tile_aligned(load_size, store_size)
&& cost(1, 2) >= cost(load_size, store_size)) {
break;
}
load_size = 1; store_size = 2;
if (is_tile_aligned(load_size, store_size)
&& cost(1, 1) >= cost(load_size, store_size)) {
break;
}
load_size = 1; store_size = 1;
} while (false);
NVTE_CHECK(is_tile_aligned(load_size, store_size),
"memory accesses are not properly aligned");
add_config(8, 8);
add_config(4, 8); add_config(8, 4);
add_config(4, 4);
add_config(2, 8); add_config(8, 2);
add_config(2, 4); add_config(4, 2);
add_config(2, 2);
add_config(1, 8); add_config(8, 1);
add_config(1, 4); add_config(4, 1);
add_config(1, 2); add_config(2, 1);
add_config(1, 1);
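// Illustration (added for clarity; values hypothetical): for an fp8
// tensor (type_size = 1) with row_length = num_rows = 2048 on a 108-SM
// GPU, the (8,8) config gives only 64 blocks (~16 active SMs) at full
// 128 B cache-line efficiency, whereas (2,4) gives 512 blocks (~108
// active SMs) with 64-element loads; the cost comparison below would
// settle on load_size = 2, store_size = 4 for this shape.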
const auto &kernel_config = *std::min_element(kernel_configs.begin(),
kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size;
const size_t store_size = kernel_config.store_size;
const size_t num_blocks = kernel_config.num_blocks;
// Compile NVRTC kernel if needed and launch
auto& rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings("transpose"
",type=", type_name,
",load_size=", load_size,
",store_size", store_size);
",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_transpose_cu;
code = regex_replace(code, "__TYPE__", type_name);
@@ -259,7 +290,7 @@ void transpose(const Tensor &input,
"transformer_engine/common/transpose/rtc/transpose.cu");
}
rtc_manager.launch(kernel_label,
num_blocks(load_size, store_size), block_size, 0, stream,
num_blocks, block_size, 0, stream,
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type*>(output.data.dptr),
......