Unverified Commit c3f8a9f5 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[Core] Kernel that swaps first two tensor dimensions (#1998)



* Add basic kernel for swapping first two tensor dims
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add NVRTC kernel for swapping first dims
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add PyTorch extension for swap first dims kernel
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Tweak variable names
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Tune kernel
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make sure writes are contiguous
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 1f2df735
...@@ -28,6 +28,7 @@ add_executable(test_operator ...@@ -28,6 +28,7 @@ add_executable(test_operator
test_multi_unpadding.cu test_multi_unpadding.cu
test_causal_softmax.cu test_causal_softmax.cu
test_swizzle.cu test_swizzle.cu
test_swap_first_dims.cu
../test_common.cu) ../test_common.cu)
find_package(OpenMP REQUIRED) find_package(OpenMP REQUIRED)
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
// CPU reference for the swap-first-dims operation.
//
// Treats the input as a 3D array [d0, d1, inner] (all dims past the
// second collapsed into one contiguous inner dimension) and writes
// output[b][a][...] = input[a][b][...].
template <typename Type>
void compute_ref(const Type *input, Type *output,
                 const std::vector<size_t> &shape) {
  const size_t d0 = shape[0];
  const size_t d1 = shape[1];
  size_t inner = 1;
  for (size_t axis = 2; axis < shape.size(); ++axis) {
    inner *= shape[axis];
  }
  // Walk the input contiguously and scatter each inner row to its
  // transposed position in the output.
  const Type *src = input;
  for (size_t a = 0; a < d0; ++a) {
    for (size_t b = 0; b < d1; ++b) {
      Type *dst = output + (b * d0 + a) * inner;
      for (size_t c = 0; c < inner; ++c) {
        dst[c] = src[c];
      }
      src += inner;
    }
  }
}
// Run nvte_swap_first_dims on a randomized tensor of the given shape and
// compare the result against the CPU reference implementation.
template <typename Type>
void performTest(const std::vector<size_t> &in_shape) {
  using namespace test;
  DType dtype = TypeInfo<Type>::dtype;

  // Output shape has the first two input dims exchanged
  std::vector<size_t> out_shape = in_shape;
  out_shape[0] = in_shape[1];
  out_shape[1] = in_shape[0];

  // Total element count
  size_t total_elems = 1;
  for (size_t axis = 0; axis < in_shape.size(); ++axis) {
    total_elems *= in_shape[axis];
  }

  // Run the Transformer Engine kernel on randomized input
  // (stream 0 = default CUDA stream)
  Tensor input("input", in_shape, dtype);
  Tensor output("output", out_shape, dtype);
  fillUniform(&input);
  nvte_swap_first_dims(input.data(), output.data(), 0);

  // Compute the expected result on the host
  std::unique_ptr<Type[]> host_ref = std::make_unique<Type[]>(total_elems);
  compute_ref<Type>(input.rowwise_cpu_dptr<Type>(), host_ref.get(), in_shape);

  // Surface any asynchronous CUDA error before comparing
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // The kernel only moves data, so require bit-exact agreement
  compareResults("output", output, host_ref.get(), true, 0, 0);
}
// Tensor shapes to test. Mixes nicely-aligned sizes with prime dims
// (which defeat vectorized access) and size-1 trailing dims.
std::vector<std::vector<size_t>> test_cases = {{4, 64, 1280},
                                               {48, 8, 128, 16},
                                               {229, 173},  // 50th and 40th primes
                                               {113, 71, 1, 1, 1, 29, 1, 1}};  // 30th, 20th, 10th primes
} // namespace
// Test suite parameterized over (dtype, tensor shape)
class SwapFirstDimsTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                          std::vector<size_t>>> {};

TEST_P(SwapFirstDimsTestSuite, TestSwapFirstDims) {
  using namespace transformer_engine;
  using namespace test;
  const DType type = std::get<0>(GetParam());
  const auto shape = std::get<1>(GetParam());
  // Dispatch on the runtime dtype to the templated test implementation
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
    performTest<T>(shape);
  );
}
// Instantiate over the cross product of all floating-point dtypes and
// all test shapes.
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  SwapFirstDimsTestSuite,
  ::testing::Combine(
      ::testing::ValuesIn(test::all_fp_types),
      ::testing::ValuesIn(test_cases)),
  // Name generator: dtype name followed by "X"-separated dims,
  // e.g. "FP32X4X64X1280"
  [](const testing::TestParamInfo<SwapFirstDimsTestSuite::ParamType>& info) {
    std::string name = test::typeName(std::get<0>(info.param));
    for (const auto& dim : std::get<1>(info.param)) {
      name += "X";
      name += std::to_string(dim);
    }
    return name;
  });
...@@ -67,6 +67,7 @@ list(APPEND transformer_engine_SOURCES ...@@ -67,6 +67,7 @@ list(APPEND transformer_engine_SOURCES
transpose/multi_cast_transpose.cu transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu activation/gelu.cu
fused_attn/flash_attn.cu fused_attn/flash_attn.cu
fused_attn/context_parallel.cu fused_attn/context_parallel.cu
...@@ -166,6 +167,8 @@ make_string_header_from_file(transpose/rtc/cast_transpose.cu ...@@ -166,6 +167,8 @@ make_string_header_from_file(transpose/rtc/cast_transpose.cu
string_code_transpose_rtc_cast_transpose_cu) string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu) string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(transpose/rtc/swap_first_dims.cu
string_code_transpose_rtc_swap_first_dims_cu)
make_string_header_from_file(utils.cuh make_string_header_from_file(utils.cuh
string_code_utils_cuh) string_code_utils_cuh)
make_string_header_from_file(util/math.h make_string_header_from_file(util/math.h
......
...@@ -318,6 +318,14 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in ...@@ -318,6 +318,14 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, cudaStream_t stream); NVTETensor output, cudaStream_t stream);
/*! \brief Swap the first two tensor dimensions.
*
* \param[in] input Input tensor of shape [M, N, ...].
* \param[out] output Output tensor of shape [N, M, ...].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "utils.cuh"
using namespace transformer_engine;
namespace {

// Parameters. The __DOUBLE_UNDERSCORE__ placeholders are substituted via
// string replacement before NVRTC compilation (see the host-side launcher).
using VectorType = BytesToType<__VECTOR_SIZE__>::Type;  // vectorized load/store unit
constexpr size_t block_size = __BLOCK_SIZE__;           // threads per block

}  // namespace

// Swap the first two dims of a contiguous tensor, viewed as a 3D array
// of vectors with (output) shape [dim0, dim1, dim2]; dim2 is the
// collapsed trailing size in vector units. Output index (i0, i1, i2)
// reads input index (i1, i0, i2), so writes are contiguous while reads
// are strided.
//
// When __SINGLE_LOAD_STORE__ is set, the grid exactly covers the data
// and each thread does one load/store; otherwise a strided loop over
// the whole range is used.
__global__ void __launch_bounds__(block_size)
    swap_first_dims_kernel(const VectorType* __restrict__ const input,
                           VectorType* __restrict__ const output, const size_t dim0,
                           const size_t dim1, const size_t dim2) {
  const size_t gid = threadIdx.x + blockIdx.x * block_size;
#if __SINGLE_LOAD_STORE__
  const auto idx = gid;
#else
  const size_t nthreads = gridDim.x * block_size;
  for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads)
#endif  // __SINGLE_LOAD_STORE__
  {
    // Decompose the linear output index and swap the two leading indices
    // to form the input offset
    const auto idx2 = idx % dim2;
    const auto idx1 = (idx / dim2) % dim1;
    const auto idx0 = (idx / dim2) / dim1;
    const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2;
    output[idx] = input[in_offset];
  }
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <algorithm>
#include <cstdint>
#include <string>
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/logging.h"
#include "../util/rtc.h"
#include "../util/string.h"
namespace transformer_engine {
namespace {
// String with RTC kernel implementation
#include "string_code_transpose_rtc_swap_first_dims_cu.h"
// Hard-coded kernel parameters
constexpr size_t block_size = 128;
/* Performance heuristics for optimized kernel parameters.
 *
 * One instance describes a candidate kernel instantiation (vector width
 * plus derived grid size) together with rough cost-model quantities.
 * Candidates are compared with operator< and the cheapest one wins.
 * dim0/dim1/dim2 follow the launcher's convention: output leading dims,
 * with dim2 the collapsed trailing size in bytes. */
struct KernelConfig {
  /* Vector load/store size (bytes per access) */
  size_t vector_size;
  /* Whether config is valid (dim2 must be divisible by vector_size) */
  bool valid = false;
  /* Number of CUDA blocks */
  size_t num_blocks = 0;
  /* Whether each thread needs to make exactly one load/store */
  bool single_load_store = true;
  /* Number of active SMs */
  size_t active_sm_count = 0;
  /* Used bytes per L1 cache load */
  size_t bytes_per_load = 0;
  /* Used bytes per L1 cache store */
  size_t bytes_per_store = 0;

  KernelConfig(size_t dim0, size_t dim1, size_t dim2, size_t sm_count, size_t vector_size_)
      : vector_size{vector_size_} {
    // Check that tiles are correctly aligned; otherwise leave invalid
    if (dim2 % vector_size_ != 0) {
      return;
    }
    valid = true;

    // Number of CUDA blocks to cover all (vectorized) elements
    num_blocks = DIVUP(dim0 * dim1 * dim2 / vector_size, block_size);
    if (num_blocks > 2147483647ull) {
      // Maximum number of CUDA blocks (2^31-1): fall back to a loop kernel
      single_load_store = false;
      num_blocks = 2147483647ull;
    } else if (num_blocks * block_size != dim0 * dim1 * dim2 / vector_size) {
      // Grid does not exactly cover the data, so a loop is needed
      single_load_store = false;
    }

    // SM occupancy
    constexpr size_t warp_size = 32;
    constexpr size_t warps_per_sm = 16;  // Rough estimate for saturated SMs
    active_sm_count = std::min(DIVUP(num_blocks * block_size / warp_size, warps_per_sm), sm_count);

    // L1 cache efficiency
    constexpr size_t cache_line_size = 128;
    bytes_per_store = std::min(cache_line_size, warp_size * vector_size);  // Contiguous writes
    bytes_per_load = bytes_per_store;
    if (dim2 % (vector_size * warp_size) != 0) {
      // Some warps are reading from two non-contiguous regions
      bytes_per_load /= 2;
    }
  }

  /* Compare by estimated cost (lower cost compares less). Invalid
     configs always compare greater than valid ones. */
  bool operator<(const KernelConfig &other) const {
    if (this->valid && other.valid) {
      // cost ~ (1/bytes_per_load + 1/bytes_per_store) / active_sms
      // Note: Integer arithmetic ensures stable ordering
      const auto &l1 = this->bytes_per_load;
      const auto &s1 = this->bytes_per_store;
      const auto &p1 = this->active_sm_count;
      const auto &l2 = other.bytes_per_load;
      const auto &s2 = other.bytes_per_store;
      const auto &p2 = other.active_sm_count;
      // Scale both costs by the product of all denominators so the
      // comparison stays in exact integer arithmetic
      const auto scale = l1 * s1 * p1 * l2 * s2 * p2;
      const auto cost1 = (scale / l1 + scale / s1) / p1;
      const auto cost2 = (scale / l2 + scale / s2) / p2;
      return cost1 < cost2;
    } else {
      return this->valid && !other.valid;
    }
  }
};
// Generic fallback kernel with no alignment requirements (the caller
// instantiates it with uint8_t, i.e. byte-wise copies). Strided loop
// over the output, viewed as a 3D array [dim0, dim1, dim2]; output
// index (i0, i1, i2) reads input index (i1, i0, i2).
template <typename Type>
__global__ void __launch_bounds__(block_size)
    swap_first_dims_untuned_kernel(const Type *__restrict__ input, Type *__restrict__ output,
                                   const size_t dim0, const size_t dim1, const size_t dim2) {
  const size_t gid = threadIdx.x + blockIdx.x * block_size;
  const size_t nthreads = gridDim.x * block_size;
  for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads) {
    // Decompose the linear output index and swap the two leading indices
    // to form the input offset
    const auto idx2 = idx % dim2;
    const auto idx1 = (idx / dim2) % dim1;
    const auto idx0 = (idx / dim2) / dim1;
    const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2;
    output[idx] = input[in_offset];
  }
}
} // namespace
/* Swap the first two dimensions of `input` into `output`.
 *
 * Both tensors must be simple (delayed-scaling) tensors with matching
 * dtypes; output shape must be [N, M, ...] for input shape [M, N, ...].
 * Dispatches to an NVRTC-compiled vectorized kernel when alignment
 * allows and NVRTC is enabled, otherwise to a generic byte-wise kernel.
 */
void swap_first_dims(const Tensor &input, Tensor &output, cudaStream_t stream) {
  // Check tensors: only plain (delayed-scaling) tensors are supported
  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be simple tensor, but scaling mode is ",
             to_string(input.scaling_mode), ".");
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor must be simple tensor, but scaling mode is ",
             to_string(output.scaling_mode), ".");
  NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()),
             ") and output tensor (dtype=", to_string(output.dtype()), ") do not match.");
  NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
  NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");

  // Check tensor dimensions: output shape must equal the input shape
  // with the first two dims exchanged
  const auto input_shape = input.shape();
  const auto output_shape = output.shape();
  NVTE_CHECK(input_shape.size() >= 2, "Invalid input tensor dimensions (shape=", input_shape, ").");
  NVTE_CHECK(output_shape.size() == input_shape.size(), "Input tensor (shape=", input_shape,
             ") and output tensor (shape=", output_shape, ") do not match.");
  NVTE_CHECK(input_shape[0] == output_shape[1], "Input tensor (shape=", input_shape,
             ") and output tensor (shape=", output_shape, ") do not match.");
  NVTE_CHECK(input_shape[1] == output_shape[0], "Input tensor (shape=", input_shape,
             ") and output tensor (shape=", output_shape, ") do not match.");
  for (size_t i = 2; i < input_shape.size(); ++i) {
    NVTE_CHECK(input_shape[i] == output_shape[i], "Input tensor (shape=", input_shape,
               ") and output tensor (shape=", output_shape, ") do not match.");
  }

  // Reinterpret tensors as 3D tensors of bytes: dim0/dim1 are the output
  // leading dims and dim2 is the byte size of the collapsed trailing dims
  const size_t dim0 = output_shape[0];
  const size_t dim1 = output_shape[1];
  size_t dim2 = 1;
  for (size_t i = 2; i < output_shape.size(); ++i) {
    dim2 *= output_shape[i];
  }
  dim2 = get_buffer_size_bytes(dim2, output.dtype());

  // Choose kernel config with performance heuristics. Start from the
  // unvectorized fallback (vector_size=1) and let wider vector widths
  // compete when NVRTC is available.
  const size_t sm_count = static_cast<size_t>(cuda::sm_count());
  KernelConfig config(dim0, dim1, dim2, sm_count, 1);
  if (rtc::is_enabled()) {
    auto try_config = [&](size_t vector_size) {
      KernelConfig new_config(dim0, dim1, dim2, sm_count, vector_size);
      if (new_config < config) {
        config = new_config;
      }
    };
    try_config(16);
    try_config(8);
    try_config(4);
    try_config(2);
  }
  const size_t vector_size = config.vector_size;

  // Launch kernel
  if (vector_size == 1) {
    // General kernel: byte-wise copies, no alignment requirements
    swap_first_dims_untuned_kernel<<<config.num_blocks, block_size, 0, stream>>>(
        static_cast<const uint8_t *>(input.data.dptr), static_cast<uint8_t *>(output.data.dptr),
        dim0, dim1, dim2);
    NVTE_CHECK_CUDA(cudaGetLastError());
  } else {
    // Compile NVRTC kernel if needed, substituting the chosen parameters
    // into the kernel source string
    auto &rtc_manager = rtc::KernelManager::instance();
    const std::string kernel_label =
        concat_strings("swap_first_dims,vector_size=", vector_size,
                       ",single_load_store=", config.single_load_store);
    if (!rtc_manager.is_compiled(kernel_label)) {
      std::string code = string_code_transpose_rtc_swap_first_dims_cu;
      code = regex_replace(code, "__VECTOR_SIZE__", vector_size);
      code = regex_replace(code, "__BLOCK_SIZE__", block_size);
      code =
          regex_replace(code, "__SINGLE_LOAD_STORE__", static_cast<int>(config.single_load_store));
      rtc_manager.compile(kernel_label, "swap_first_dims_kernel", code,
                          "transformer_engine/common/transpose/rtc/swap_first_dims.cu");
    }
    // Launch NVRTC kernel. Note: dim2 is passed in vector units here,
    // not bytes.
    rtc_manager.launch(kernel_label, config.num_blocks, block_size, 0, stream, input.data.dptr,
                       output.data.dptr, dim0, dim1, dim2 / vector_size);
  }
}
} // namespace transformer_engine
// C API entry point: swap the first two dims of `input` into `output`
// on `stream`. See transformer_engine/transpose.h for the contract.
void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swap_first_dims);
  using namespace transformer_engine;
  // Validate the handles and dispatch to the internal implementation
  swap_first_dims(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), stream);
}
...@@ -143,6 +143,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -143,6 +143,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
at::Tensor fp8_transpose(at::Tensor input, DType otype, at::Tensor fp8_transpose(at::Tensor input, DType otype,
std::optional<at::Tensor> output = std::nullopt); std::optional<at::Tensor> output = std::nullopt);
at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = std::nullopt);
/*************************************************************************************************** /***************************************************************************************************
* Activations * Activations
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -211,6 +211,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -211,6 +211,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims,
"Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend, m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
"Get Fused Attention backend", py::call_guard<py::gil_scoped_release>()); "Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &transformer_engine::pytorch::compute_amax, m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
......
...@@ -52,5 +52,30 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor ...@@ -52,5 +52,30 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor
return out; return out;
} }
// PyTorch binding: swap the first two dimensions of a tensor.
//
// tensor: input of shape [M, N, ...]; made contiguous if needed.
// out: optional preallocated output of shape [N, M, ...]; allocated
//      here (same dtype/device as the input) when not provided.
// Returns the output tensor.
at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
  init_extension();

  // Make sure input is contiguous
  // (const ref extends the lifetime of the temporary when a copy is made)
  const auto &input = tensor.contiguous();

  // Allocate output tensor if needed
  if (!out) {
    auto in_shape = getTensorShape(input);
    NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")");
    std::vector<int64_t> out_shape_int64(in_shape.begin(), in_shape.end());
    out_shape_int64[0] = static_cast<int64_t>(in_shape[1]);
    out_shape_int64[1] = static_cast<int64_t>(in_shape[0]);
    auto opts = at::TensorOptions().dtype(input.dtype()).device(input.device());
    out = at::empty(out_shape_int64, opts);
  }

  // Launch kernel on the current PyTorch CUDA stream
  const TensorWrapper te_input = makeTransformerEngineTensor(input);
  TensorWrapper te_output = makeTransformerEngineTensor(*out);
  nvte_swap_first_dims(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
  return std::move(*out);
}
} // namespace pytorch } // namespace pytorch
} // namespace transformer_engine } // namespace transformer_engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment