Unverified commit 30632f31, authored by Tim Moon, committed by GitHub

Use 4B vector loads/stores in cast-transpose kernel for small matrices (#101)


Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 277b0be2
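The diff below replaces the fixed 8B desired_load_size/desired_store_size constants with a runtime choice between 8B and 4B vectors: estimate the grid an 8B-vectorized launch would produce, and drop to 4B when that grid cannot fill the GPU. The following minimal host-side sketch mirrors that heuristic under stated assumptions; pick_vector_size, div_up, and the hard-coded constants (32 threads per warp, 256 for cast_transpose_num_threads, the 128-SM estimate) are illustrative stand-ins, not Transformer Engine API.

#include <cstddef>
#include <cstdio>

// Illustrative constants, assumed to match the kernel source:
constexpr size_t kThreadsPerWarp = 32;  // THREADS_PER_WARP
constexpr size_t kWarpsPerTile = 4;     // n_warps_per_tile
constexpr size_t kNumThreads = 256;     // assumed cast_transpose_num_threads
constexpr size_t kNumSMs = 128;         // same hard-coded SM estimate as the diff

constexpr size_t div_up(size_t a, size_t b) { return (a + b - 1) / b; }

// Tiles needed to cover a num_rows x row_length matrix with the given
// load/store widths in bytes, mirroring get_n_tiles in the diff.
size_t n_tiles(size_t row_length, size_t num_rows,
               size_t itype_size, size_t otype_size,
               size_t load_size, size_t store_size) {
  const size_t nvec_in = load_size / itype_size;
  const size_t nvec_out = store_size / otype_size;
  return div_up(row_length, nvec_in * kThreadsPerWarp) *
         div_up(num_rows, nvec_out * kThreadsPerWarp);
}

// Blocks launched for a given tile count, mirroring get_n_blocks.
size_t n_blocks(size_t tiles) {
  const size_t warps_per_block = kNumThreads / kThreadsPerWarp;
  return div_up(tiles * kWarpsPerTile, warps_per_block);
}

// The commit's heuristic: prefer 8B vectors, but fall back to 4B when the
// 8B grid would leave SMs idle.
size_t pick_vector_size(size_t row_length, size_t num_rows,
                        size_t itype_size, size_t otype_size) {
  const size_t blocks8 = n_blocks(n_tiles(row_length, num_rows,
                                          itype_size, otype_size, 8, 8));
  return blocks8 >= kNumSMs ? 8 : 4;
}

int main() {
  // bf16 input (2B), fp8 output (1B): a 1024x1024 matrix yields only
  // 16 blocks with 8B vectors, so the 4B path is chosen; 8192x8192
  // produces 1024 blocks and keeps 8B vectors.
  std::printf("1024x1024 -> %zuB\n", pick_vector_size(1024, 1024, 2, 1));
  std::printf("8192x8192 -> %zuB\n", pick_vector_size(8192, 8192, 2, 1));
  return 0;
}

Note that halving both vector widths roughly quadruples the tile count for the same matrix (nvec_in and nvec_out each shrink by half), which is exactly the lever the commit uses to raise SM utilization on small shapes.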
@@ -47,8 +47,6 @@ inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
 // STUFF TO TUNE
 constexpr unsigned int n_warps_per_tile = 4;
-constexpr int desired_load_size = 8;
-constexpr int desired_store_size = 8;
 constexpr unsigned int max_threads_per_block = 256;
 static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
@@ -321,61 +319,94 @@ void cast_transpose(const Tensor &input,
   NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
              "C and T outputs need to share scale tensor.");
 
+// Launch specific cast-transpose kernel
+#define LAUNCH_KERNEL(kernel, nvec_in, nvec_out, n_tiles, n_blocks, InputType, OutputType) \
+  do { \
+    cudaFuncSetAttribute(kernel<nvec_in, nvec_out, fp32, InputType, OutputType>, \
+                         cudaFuncAttributePreferredSharedMemoryCarveout, \
+                         100); \
+    kernel<nvec_in, nvec_out, fp32, InputType, OutputType> \
+      <<<n_blocks, \
+         cast_transpose_num_threads, \
+         cast_transpose_num_threads / n_warps_per_tile * \
+         (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>), \
+         stream>>>( \
+          reinterpret_cast<const InputType *>(input.data.dptr), \
+          reinterpret_cast<OutputType *>(cast_output->data.dptr), \
+          reinterpret_cast<OutputType *>(transposed_output->data.dptr), \
+          reinterpret_cast<const fp32 *>(cast_output->scale.dptr), \
+          reinterpret_cast<fp32 *>(cast_output->amax.dptr), \
+          row_length, num_rows, n_tiles); \
+  } while (false)
+
+// Launch cast-transpose kernel for given vector sizes
+#define LAUNCH_KERNEL_VEC_SIZES(load_size, store_size, InputType, OutputType) \
+  do { \
+    constexpr int nvec_in = load_size / sizeof(InputType); \
+    constexpr int nvec_out = store_size / sizeof(OutputType); \
+    \
+    NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); \
+    NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); \
+    \
+    const size_t n_tiles = get_n_tiles(load_size, store_size); \
+    const size_t n_blocks = get_n_blocks(n_tiles); \
+    \
+    const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && \
+                           num_rows % (nvec_out * THREADS_PER_WARP) == 0; \
+    \
+    if (full_tile) { \
+      LAUNCH_KERNEL(cast_transpose_kernel, \
+                    nvec_in, nvec_out, n_tiles, n_blocks, \
+                    InputType, OutputType); \
+    } else { \
+      LAUNCH_KERNEL(cast_transpose_kernel_notaligned, \
+                    nvec_in, nvec_out, n_tiles, n_blocks, \
+                    InputType, OutputType); \
+    } \
+  } while (false)
+
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
-      constexpr int itype_size = sizeof(InputType);
-      constexpr int otype_size = sizeof(OutputType);
-      constexpr int nvec_in = desired_load_size / itype_size;
-      constexpr int nvec_out = desired_store_size / otype_size;
-
-      NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
-      NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
-
-      const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
-                             DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
-      const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
-      const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
-
-      const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
-                             num_rows % (nvec_out * THREADS_PER_WARP) == 0;
-
-      if (full_tile) {
-        cudaFuncSetAttribute(cast_transpose_kernel<nvec_in, nvec_out, fp32,
-                                                   InputType, OutputType>,
-                             cudaFuncAttributePreferredSharedMemoryCarveout,
-                             100);
-        cast_transpose_kernel<nvec_in, nvec_out, fp32, InputType, OutputType>
-          <<<n_blocks,
-             cast_transpose_num_threads,
-             cast_transpose_num_threads / n_warps_per_tile *
-             (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
-             stream>>>(
-            reinterpret_cast<const InputType *>(input.data.dptr),
-            reinterpret_cast<OutputType *>(cast_output->data.dptr),
-            reinterpret_cast<OutputType *>(transposed_output->data.dptr),
-            reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
-            reinterpret_cast<fp32 *>(cast_output->amax.dptr),
-            row_length, num_rows, n_tiles);
-      } else {
-        cudaFuncSetAttribute(cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
-                                                              InputType, OutputType>,
-                             cudaFuncAttributePreferredSharedMemoryCarveout,
-                             100);
-        cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32, InputType, OutputType>
-          <<<n_blocks,
-             cast_transpose_num_threads,
-             cast_transpose_num_threads / n_warps_per_tile *
-             (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
-             stream>>>(
-            reinterpret_cast<const InputType *>(input.data.dptr),
-            reinterpret_cast<OutputType *>(cast_output->data.dptr),
-            reinterpret_cast<OutputType *>(transposed_output->data.dptr),
-            reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
-            reinterpret_cast<fp32 *>(cast_output->amax.dptr),
-            row_length, num_rows, n_tiles);
-      }
+      // Estimate number of SMs
+      // Note: H100 has 132 SMs, A100 has 108 SMs.
+      // Note: Directly querying number of SMs with cudaGetDeviceProperties is
+      // slow (>1 ms). Consider querying once and caching.
+      const int n_sms = 128;
+
+      // Helper functions to get kernel configuration
+      auto get_n_tiles = [=] (size_t load_size, size_t store_size) -> int {
+        constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP);
+        size_t nvec_in = load_size / sizeof(InputType);
+        size_t nvec_out = store_size / sizeof(OutputType);
+        size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) *
+                         DIVUP(num_rows, nvec_out * threads_per_warp);
+        return n_tiles;
+      };
+      auto get_n_blocks = [=] (size_t n_tiles) -> int {
+        size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
+        size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
+        return n_blocks;
+      };
+
+      // Estimate optimal vector sizes and run
+      // Note: Consider reducing to 2B or 1B loads/stores for
+      // sufficiently small matrices. Need to consider whether reduced
+      // cache efficiency is worth increased SM utilization. Also need
+      // to keep in mind whether datatype can fit.
+      const size_t estimated_n_tiles = get_n_tiles(8, 8);
+      const size_t estimated_n_blocks = get_n_blocks(estimated_n_tiles);
+      if (estimated_n_blocks >= n_sms) {
+        LAUNCH_KERNEL_VEC_SIZES(8, 8, InputType, OutputType);
+      } else {
+        LAUNCH_KERNEL_VEC_SIZES(4, 4, InputType, OutputType);
+      }
     ); // NOLINT(*)
   ); // NOLINT(*)
+
+#undef LAUNCH_KERNEL
+#undef LAUNCH_KERNEL_VEC_SIZES
 }
 } // namespace transformer_engine
...
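A closing note on the LAUNCH_KERNEL macro above: its third launch argument is the dynamic shared-memory size, cast_transpose_num_threads / n_warps_per_tile * (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>). The + 1 is the classic padding column that keeps the transpose's strided shared-memory accesses from all landing on the same bank. The arithmetic below is a worked example under stated assumptions (cast_transpose_num_threads == 256, 8B stores of a 1-byte output type, so the Vec is 8 bytes), not code from the commit:

#include <cstdio>

int main() {
  const int threads_per_warp = 32;             // THREADS_PER_WARP
  const int n_warps_per_tile = 4;              // tuning constant from the diff
  const int cast_transpose_num_threads = 256;  // assumed value
  const int vec_bytes = 8;                     // sizeof(Vec<OutputType, nvec_out>),
                                               // e.g. nvec_out = 8 of a 1B fp8 type

  // Same expression as the launch configuration in LAUNCH_KERNEL:
  const int smem_bytes = cast_transpose_num_threads / n_warps_per_tile *
                         (threads_per_warp + 1) * vec_bytes;
  std::printf("dynamic shared memory: %d bytes\n", smem_bytes);  // 64 * 33 * 8 = 16896
  return 0;
}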