Unverified commit 2a3916b4, authored by Tim Moon, committed by GitHub

Multi-tensor cast-transpose (#18)



* Add kernel for multi-tensor cast-transpose

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix incorrect test function in multi-tensor cast-transpose unit test

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove std::vector from multi-tensor cast-transpose function signature

Makes sure the main header is C-compatible.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent d10dfb57
@@ -2,14 +2,16 @@
#
# See LICENSE for license information.

add_executable(test_operator
               test_qdq.cu
               test_cast_transpose.cu
               test_transpose.cu
               test_cast_transpose_dbias.cu
               test_cast_transpose_dbias_dgelu.cu
               test_gelu.cu
               test_layernorm.cu
               test_multi_cast_transpose.cu
               ../test_common.cu)

target_link_libraries(test_operator PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB})
......
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cstring>
#include <iostream>
#include <iomanip>
#include <memory>
#include <random>
#include <vector>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const std::vector<std::vector<InputType>>& input_list,
                 std::vector<std::vector<OutputType>>& output_c_list,
                 std::vector<std::vector<OutputType>>& output_t_list,
                 const std::vector<float>& scale_list,
                 std::vector<float>& amax_list,
                 std::vector<float>& scale_inv_list,
                 const std::vector<size_t>& height_list,
                 const std::vector<size_t>& width_list) {
  using compute_t = float;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = input_list[tensor_id];
    auto& output_c = output_c_list[tensor_id];
    auto& output_t = output_t_list[tensor_id];
    const compute_t scale = scale_list[tensor_id];
    compute_t& amax = amax_list[tensor_id];
    compute_t& scale_inv = scale_inv_list[tensor_id];
    const size_t height = height_list[tensor_id];
    const size_t width = width_list[tensor_id];
    scale_inv = 1. / scale;
    amax = -1e100;
    for (size_t i = 0; i < height; ++i) {
      for (size_t j = 0; j < width; ++j) {
        const compute_t x = static_cast<compute_t>(input[i * width + j]);
        const OutputType y = static_cast<OutputType>(scale * x);
        amax = fmaxf(amax, fabsf(x));
        output_c[i * width + j] = y;
        output_t[j * height + i] = y;
      }
    }
  }
}
template <typename InputType, typename OutputType>
void performTest() {
  using namespace test;

  const DType itype = TypeInfo<InputType>::dtype;
  const DType otype = TypeInfo<OutputType>::dtype;
  const DType ctype = DType::kFloat32;
  const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1, 1},
                                                              {1, 768},
                                                              {768, 1},
                                                              {768, 768},
                                                              {43, 43},
                                                              {43, 256},
                                                              {256, 43},
                                                              {256, 256}};
  const size_t num_tensors = tensor_dims.size();

  // Buffers for Transformer Engine implementation
  std::vector<Tensor> input_list, output_c_list, output_t_list,
      scale_list, amax_list, scale_inv_list;

  // Buffers for reference implementation
  std::vector<std::vector<InputType>> ref_input_list;
  std::vector<std::vector<OutputType>> ref_output_c_list, ref_output_t_list;
  std::vector<float> ref_scale_list(num_tensors), ref_amax_list(num_tensors),
      ref_scale_inv_list(num_tensors);
  std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);

  // Initialize buffers
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    const size_t height = tensor_dims[tensor_id].first;
    const size_t width = tensor_dims[tensor_id].second;
    input_list.emplace_back(Tensor({ height, width }, itype));
    output_c_list.emplace_back(Tensor({ height, width }, otype));
    output_t_list.emplace_back(Tensor({ width, height }, otype));
    scale_list.emplace_back(Tensor({ 1 }, ctype));
    amax_list.emplace_back(Tensor({ 1 }, ctype));
    scale_inv_list.emplace_back(Tensor({ 1 }, ctype));
    auto& input = input_list.back();
    auto& scale = scale_list.back();
    fillUniform(input);
    fillUniform(scale);
    *scale.cpu_dptr<float>() += 2.5;
    scale.from_cpu();
    ref_input_list.emplace_back(height * width);
    ref_output_c_list.emplace_back(height * width);
    ref_output_t_list.emplace_back(width * height);
    std::copy(input.cpu_dptr<InputType>(),
              input.cpu_dptr<InputType>() + height * width,
              ref_input_list.back().begin());
    ref_scale_list[tensor_id] = *scale.cpu_dptr<float>();
    ref_height_list[tensor_id] = height;
    ref_width_list[tensor_id] = width;
  }

  // Transformer Engine implementation
  auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
      -> std::vector<NVTETensor> {
    std::vector<NVTETensor> nvte_tensor_list;
    for (auto& tensor : tensor_list) {
      nvte_tensor_list.emplace_back(tensor.data());
    }
    return nvte_tensor_list;
  };
  nvte_multi_cast_transpose(num_tensors,
                            make_nvte_vector(input_list).data(),
                            make_nvte_vector(scale_list).data(),
                            make_nvte_vector(output_c_list).data(),
                            make_nvte_vector(output_t_list).data(),
                            make_nvte_vector(amax_list).data(),
                            make_nvte_vector(scale_inv_list).data(),
                            0);

  // Reference implementation
  compute_ref<InputType, OutputType>(ref_input_list,
                                     ref_output_c_list,
                                     ref_output_t_list,
                                     ref_scale_list,
                                     ref_amax_list,
                                     ref_scale_inv_list,
                                     ref_height_list,
                                     ref_width_list);

  // Check correctness
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax",
                   amax_list[tensor_id],
                   &ref_amax_list[tensor_id],
                   atol_amax, rtol_amax);
    compareResults("scale_inv",
                   scale_inv_list[tensor_id],
                   &ref_scale_inv_list[tensor_id],
                   atol_amax, rtol_amax);
    auto [atol, rtol] = getTolerances(otype);
    compareResults("output_c",
                   output_c_list[tensor_id],
                   ref_output_c_list[tensor_id].data(),
                   atol, rtol);
    compareResults("output_t",
                   output_t_list[tensor_id],
                   ref_output_t_list[tensor_id].data(),
                   atol, rtol);
  }
}
}  // namespace

class MultiCastTransposeTestSuite
    : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                 transformer_engine::DType>> {};

TEST_P(MultiCastTransposeTestSuite, TestMultiCastTranspose) {
  using namespace transformer_engine;
  using namespace test;
  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>();
    );
  );
}

INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    MultiCastTransposeTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types)),
    [](const testing::TestParamInfo<MultiCastTransposeTestSuite::ParamType>& info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param));
      return name;
    });
@@ -25,6 +25,7 @@ add_library(transformer_engine SHARED
            transpose/cast_transpose.cu
            transpose/transpose.cu
            transpose/cast_transpose_fusion.cu
            transpose/multi_cast_transpose.cu
            activation/gelu.cu
            gemm/cublaslt_gemm.cu
            layer_norm/ln_api.cpp
......
@@ -115,6 +115,33 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
                                     NVTETensor workspace,
                                     cudaStream_t stream);
/*! \brief Cast and transpose multiple tensors.
 *
 *  This function casts each input tensor and produces two outputs:
 *  - `cast_output` is the result of the cast
 *  - `transposed_output` is the transposed result of the cast.
 *
 *  \param[in]     num_tensors            Number of tensors.
 *  \param[in]     input_list             List of 2D input tensors.
 *  \param[in]     scale_list             Scaling factors used to generate the outputs.
 *  \param[out]    cast_output_list       List of cast tensors. Dimensions
 *                                        match the tensors in input_list.
 *  \param[out]    transposed_output_list List of cast and transposed
 *                                        tensors. Dimensions are the transpose
 *                                        of the tensors in input_list.
 *  \param[in,out] amax_list              AMAX values of the output tensors.
 *  \param[out]    scale_inv_list         Inverses of the scaling factors.
 *  \param[in]     stream                 CUDA stream used for the operation.
 */
void nvte_multi_cast_transpose(size_t num_tensors,
                               const NVTETensor* input_list,
                               const NVTETensor* scale_list,
                               NVTETensor* cast_output_list,
                               NVTETensor* transposed_output_list,
                               NVTETensor* amax_list,
                               NVTETensor* scale_inv_list,
                               cudaStream_t stream);
#ifdef __cplusplus
}  // extern "C"
#endif
......
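For orientation, a minimal calling sketch for the new C API (not part of the commit). It assumes NVTETensor handles have already been created for every buffer and are paired index-for-index across the six lists; the unit test above builds such handles from its test::Tensor helper via .data(). The wrapper function name here is illustrative.

#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <vector>

// Hypothetical wrapper: casts and transposes a whole batch in one call.
void cast_transpose_all(std::vector<NVTETensor>& inputs,
                        std::vector<NVTETensor>& scales,
                        std::vector<NVTETensor>& cast_outputs,
                        std::vector<NVTETensor>& transposed_outputs,
                        std::vector<NVTETensor>& amaxes,
                        std::vector<NVTETensor>& scale_invs,
                        cudaStream_t stream) {
  // One call processes the whole list; the implementation batches up to
  // 64 tensors per kernel launch internally (kMaxTensorsPerKernel).
  nvte_multi_cast_transpose(inputs.size(),
                            inputs.data(),
                            scales.data(),
                            cast_outputs.data(),
                            transposed_outputs.data(),
                            amaxes.data(),
                            scale_invs.data(),
                            stream);
}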
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <iostream>
#include <cfloat>
#include <vector>
#include "../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace {
// Parameters to tune
constexpr int n_warps_per_tile = 4;
constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
constexpr int desired_load_size = 8;
constexpr int desired_store_size = 8;
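// With 8-byte accesses, each thread moves 8 / sizeof(T) elements per vector:
// 2 for fp32, 4 for fp16/bf16, 8 for fp8.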
constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
struct MultiCastTransposeArgs {
  // (input) Data buffers for input tensors
  void* input_list[kMaxTensorsPerKernel];
  // (output) Data buffers for cast output tensors
  void* output_c_list[kMaxTensorsPerKernel];
  // (output) Data buffers for transpose output tensors
  void* output_t_list[kMaxTensorsPerKernel];
  // (input) Scaling factor for output tensors
  void* scale_list[kMaxTensorsPerKernel];
  // (output) AMAX's of input tensors
  void* amax_list[kMaxTensorsPerKernel];
  // (output) Reciprocal of scaling factors
  void* scale_inv_list[kMaxTensorsPerKernel];
  // Input matrix heights
  int num_rows_list[kMaxTensorsPerKernel];
  // Input matrix widths
  int row_length_list[kMaxTensorsPerKernel];
  // Prefix sum (with leading zero) of CUDA blocks needed for each tensor
  int block_range[kMaxTensorsPerKernel+1];
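  // e.g. per-tensor tile counts {3, 5, 2} give block_range = {0, 3, 8, 10}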
  // Number of tensors being processed by kernel
  int num_tensors;
};
template <
  int nvec_in,
  int nvec_out,
  bool aligned,
  typename CType,
  typename IType,
  typename OType>
__global__ void
__launch_bounds__(threads_per_block)
multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
  using IVec = Vec<IType, nvec_in>;
  using OVecC = Vec<OType, nvec_in>;
  using OVecT = Vec<OType, nvec_out>;

  // Thread indices
  // Note: Block is interpreted as a warp_size x num_warps grid
  constexpr int bdimx = THREADS_PER_WARP;
  constexpr int bdimy = n_warps_per_tile;
  const int tid = threadIdx.x;
  const int tidx = tid % bdimx;
  const int tidy = tid / bdimx;
  const int bid = blockIdx.x;

  // Input tensors are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
  constexpr int tile_dim_m = THREADS_PER_WARP * nvec_out;
  constexpr int tile_dim_n = THREADS_PER_WARP * nvec_in;

  // Number of nvec_out x nvec_in subtiles for each thread to load/store
  constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
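  // e.g. with a 32-thread warp and 4 warps per tile, each thread covers
  // 32 / 4 = 8 subtiles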

  // Find tensor corresponding to block
  int tensor_id = 0;
  while (args.block_range[tensor_id+1] <= bid) {
    ++tensor_id;
  }
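  // e.g. block_range = {0, 12, 20}: a block with bid = 15 handles tensor 1,
  // and the tile index within that tensor works out below to 15 - 12 = 3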
  const IType* input = reinterpret_cast<const IType*>(args.input_list[tensor_id]);
  OType* output_c = reinterpret_cast<OType*>(args.output_c_list[tensor_id]);
  OType* output_t = reinterpret_cast<OType*>(args.output_t_list[tensor_id]);
  const CType scale = *reinterpret_cast<CType*>(args.scale_list[tensor_id]);
  CType* amax = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
  CType* scale_inv = reinterpret_cast<CType*>(args.scale_inv_list[tensor_id]);
  const int num_rows = args.num_rows_list[tensor_id];
  const int row_length = args.row_length_list[tensor_id];

  // Find position of tile within tensor
  const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
  const int tile_id = bid - args.block_range[tensor_id];
  const int tile_id_m = tile_id / num_tiles_n;
  const int tile_id_n = tile_id % num_tiles_n;
  const int tile_row = tile_id_m * tile_dim_m;
  const int tile_col = tile_id_n * tile_dim_n;

  // Load input and store to registers
  // Note: Each thread loads n_iterations subtiles, casts to output
  // type, and transposes in registers.
  OVecT local_output_t[nvec_in][n_iterations];
  CType local_amax = 0;
  #pragma unroll
  for (int iter = 0; iter < n_iterations; ++iter) {
    const int i1 = tidy + iter * bdimy;
    const int j1 = tidx;
    #pragma unroll
    for (int i2 = 0; i2 < nvec_out; ++i2) {
      const int row = tile_row + i1 * nvec_out + i2;
      const int col = tile_col + j1 * nvec_in;
      IVec local_input;
      OVecC local_output_c;
      if constexpr (aligned) {
        local_input.load_from(&input[row * row_length + col]);
      } else {
        local_input.clear();
        if (row < num_rows) {
          #pragma unroll
          for (int j2 = 0; j2 < nvec_in; ++j2) {
            if (col + j2 < row_length) {
              local_input.data.elt[j2] = input[row * row_length + col + j2];
            }
          }
        }
      }
      #pragma unroll
      for (int j2 = 0; j2 < nvec_in; ++j2) {
        const CType x = CType(local_input.data.elt[j2]);
        const OType y = OType(scale * x);
        local_output_c.data.elt[j2] = y;
        local_output_t[j2][iter].data.elt[i2] = y;
        __builtin_assume(local_amax >= 0);
        local_amax = fmaxf(fabsf(x), local_amax);
      }
      if constexpr (aligned) {
        local_output_c.store_to(&output_c[row * row_length + col]);
      } else {
        if (row < num_rows) {
          #pragma unroll
          for (int j2 = 0; j2 < nvec_in; ++j2) {
            if (col + j2 < row_length) {
              output_c[row * row_length + col + j2] = local_output_c.data.elt[j2];
            }
          }
        }
      }
    }
  }

  // Copy transposed output from registers to global memory
  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1];
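  // The +1 padding on the inner dimension staggers columns across
  // shared-memory banks, avoiding bank conflicts on the transposed reads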
  #pragma unroll
  for (int j2 = 0; j2 < nvec_in; ++j2) {
    #pragma unroll
    for (int iter = 0; iter < n_iterations; ++iter) {
      const int i1 = tidy + iter * bdimy;
      const int j1 = tidx;
      shared_output_t[j1][i1] = local_output_t[j2][iter];
    }
    __syncthreads();
    #pragma unroll
    for (int iter = 0; iter < n_iterations; ++iter) {
      const int i1 = tidx;
      const int j1 = tidy + iter * bdimy;
      const int row = tile_row + i1 * nvec_out;
      const int col = tile_col + j1 * nvec_in + j2;
      if constexpr (aligned) {
        shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
      } else {
        if (col < row_length) {
          #pragma unroll
          for (int i2 = 0; i2 < nvec_out; ++i2) {
            if (row + i2 < num_rows) {
              output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
            }
          }
        }
      }
    }
    __syncthreads();
  }

  // Finalize fp8 factors
  local_amax = reduce_max<n_warps_per_tile>(local_amax, tidy);
  if (tid == 0) {
    static_assert(std::is_same<CType, float>::value);
    atomicMaxFloat(amax, local_amax);
  }
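  // Only one block per tensor needs to write scale_inv; tile 0 is chosen
  // arbitrarily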
  if (tid == 0 && tile_id == 0) {
    reciprocal<float>(scale_inv, scale);
  }
}
} // namespace
void multi_cast_transpose(const std::vector<Tensor*> input_list,
                          const std::vector<Tensor*> scale_list,
                          std::vector<Tensor*> cast_output_list,
                          std::vector<Tensor*> transposed_output_list,
                          std::vector<Tensor*> amax_list,
                          std::vector<Tensor*> scale_inv_list,
                          cudaStream_t stream) {
  // Check that number of tensors is valid
  NVTE_CHECK(scale_list.size() == input_list.size(),
             "Number of input and scale tensors must match");
  NVTE_CHECK(cast_output_list.size() == input_list.size(),
             "Number of input and C output tensors must match");
  NVTE_CHECK(transposed_output_list.size() == input_list.size(),
             "Number of input and T output tensors must match");
  NVTE_CHECK(amax_list.size() == input_list.size(),
             "Number of input and AMAX tensors must match");
  NVTE_CHECK(scale_inv_list.size() == input_list.size(),
             "Number of input and scale_inv tensors must match");
  if (input_list.empty()) {
    return;
  }

  // Check that tensor properties are valid
  DType ctype = DType::kFloat32;
  DType itype = input_list[0]->dtype;
  DType otype = cast_output_list[0]->dtype;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = *input_list[tensor_id];
    const auto& scale = *scale_list[tensor_id];
    const auto& cast_output = *cast_output_list[tensor_id];
    const auto& transposed_output = *transposed_output_list[tensor_id];
    const auto& amax = *amax_list[tensor_id];
    const auto& scale_inv = *scale_inv_list[tensor_id];
    NVTE_CHECK(input.dtype == itype,
               "Input tensor types do not match.");
    NVTE_CHECK(scale.dtype == ctype,
               "Scale tensor must have Float32 type.");
    NVTE_CHECK(cast_output.dtype == otype,
               "C output tensor types do not match.");
    NVTE_CHECK(transposed_output.dtype == otype,
               "T output tensor types do not match.");
    NVTE_CHECK(amax.dtype == ctype,
               "AMAX tensor must have Float32 type.");
    NVTE_CHECK(scale_inv.dtype == ctype,
               "scale_inv tensor must have Float32 type.");
    NVTE_CHECK(input.shape.size() == 2,
               "Input tensor must have 2 dimensions.");
    NVTE_CHECK(cast_output.shape == input.shape,
               "C output tensor shape does not match input tensor.");
    NVTE_CHECK(transposed_output.shape.size() == 2,
               "T output tensor shape does not match input tensor.");
    NVTE_CHECK(transposed_output.shape[0] == input.shape[1],
               "T output tensor shape does not match input tensor.");
    NVTE_CHECK(transposed_output.shape[1] == input.shape[0],
               "T output tensor shape does not match input tensor.");
    NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
    NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
    NVTE_CHECK(cast_output.dptr != nullptr, "C output is not allocated.");
    NVTE_CHECK(transposed_output.dptr != nullptr, "T output is not allocated.");
    NVTE_CHECK(amax.dptr != nullptr, "AMAX output is not allocated.");
    NVTE_CHECK(scale_inv.dptr != nullptr, "scale_inv output is not allocated.");
  }

  // Input matrices are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
  const int tile_dim_m = THREADS_PER_WARP * desired_store_size / typeToSize(otype);
  const int tile_dim_n = THREADS_PER_WARP * desired_load_size / typeToSize(itype);
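  // e.g. fp32 input cast to fp8 output (32-thread warp):
  // tile_dim_n = 32 * 8 / 4 = 64 and tile_dim_m = 32 * 8 / 1 = 256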

  // Add tensors to kernel argument struct
  MultiCastTransposeArgs kernel_args_aligned, kernel_args_unaligned;
  kernel_args_aligned.num_tensors = 0;
  kernel_args_aligned.block_range[0] = 0;
  kernel_args_unaligned.num_tensors = 0;
  kernel_args_unaligned.block_range[0] = 0;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    // Launch kernel if argument struct is full
    if (kernel_args_aligned.num_tensors == kMaxTensorsPerKernel) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType,
          constexpr int nvec_in = desired_load_size / sizeof(InputType);
          constexpr int nvec_out = desired_store_size / sizeof(OutputType);
          const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
          multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
            <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);
        );  // NOLINT(*)
      );  // NOLINT(*)
      kernel_args_aligned.num_tensors = 0;
    }
    if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType,
          constexpr int nvec_in = desired_load_size / sizeof(InputType);
          constexpr int nvec_out = desired_store_size / sizeof(OutputType);
          const int n_blocks = kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
          multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
            <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);
        );  // NOLINT(*)
      );  // NOLINT(*)
      kernel_args_unaligned.num_tensors = 0;
    }

    // Calculate number of thread blocks needed for tensor
    const int num_rows = input_list[tensor_id]->shape[0];
    const int row_length = input_list[tensor_id]->shape[1];
    const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
    const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
    const int num_tiles = num_tiles_m * num_tiles_n;

    // Figure out whether to use aligned or unaligned kernel
    const bool aligned = ((num_tiles_m * tile_dim_m == num_rows)
                          && (num_tiles_n * tile_dim_n == row_length));
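    // e.g. with 256x64 tiles, a 256x256 tensor tiles exactly and takes the
    // aligned (fully vectorized) path, while 43x43 falls back to the
    // unaligned kernel with bounds checks; both shapes appear in the unit test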
    auto& kernel_args = aligned ? kernel_args_aligned : kernel_args_unaligned;

    // Add tensor to kernel argument struct
    const int pos = kernel_args.num_tensors;
    kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->dptr);
    kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->dptr;
    kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->dptr;
    kernel_args.scale_list[pos] = const_cast<void*>(scale_list[tensor_id]->dptr);
    kernel_args.amax_list[pos] = amax_list[tensor_id]->dptr;
    kernel_args.scale_inv_list[pos] = scale_inv_list[tensor_id]->dptr;
    kernel_args.num_rows_list[pos] = num_rows;
    kernel_args.row_length_list[pos] = row_length;
    kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles;
    kernel_args.num_tensors++;
  }

  // Launch kernels for any remaining tensors
  if (kernel_args_aligned.num_tensors > 0) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType,
        constexpr int nvec_in = desired_load_size / sizeof(InputType);
        constexpr int nvec_out = desired_store_size / sizeof(OutputType);
        const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
        multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
          <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);
      );  // NOLINT(*)
    );  // NOLINT(*)
  }
  if (kernel_args_unaligned.num_tensors > 0) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType,
        constexpr int nvec_in = desired_load_size / sizeof(InputType);
        constexpr int nvec_out = desired_store_size / sizeof(OutputType);
        const int n_blocks = kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
        multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
          <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);
      );  // NOLINT(*)
    );  // NOLINT(*)
  }
}
} // namespace transformer_engine
void nvte_multi_cast_transpose(size_t num_tensors,
                               const NVTETensor* input_list,
                               const NVTETensor* scale_list,
                               NVTETensor* cast_output_list,
                               NVTETensor* transposed_output_list,
                               NVTETensor* amax_list,
                               NVTETensor* scale_inv_list,
                               cudaStream_t stream) {
  using namespace transformer_engine;
  std::vector<Tensor*> input_list_, scale_list_,
      cast_output_list_, transposed_output_list_, amax_list_, scale_inv_list_;
  for (size_t i = 0; i < num_tensors; ++i) {
    input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
    scale_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(scale_list[i])));
    cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i]));
    transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i]));
    amax_list_.push_back(reinterpret_cast<Tensor*>(amax_list[i]));
    scale_inv_list_.push_back(reinterpret_cast<Tensor*>(scale_inv_list[i]));
  }
  multi_cast_transpose(input_list_,
                       scale_list_,
                       cast_output_list_,
                       transposed_output_list_,
                       amax_list_,
                       scale_inv_list_,
                       stream);
}
@@ -352,3 +352,81 @@ void dispatch_bgrad_dgelu_cast_transpose_fusion(
      amax_cu.data(), dbias_cu.data(), scale_inv_cu.data(),
      workspace.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_multi_cast_transpose(
    std::vector<void*> input_dptr_list,  // i
    const std::vector<std::vector<size_t>>& input_shape_list,
    const std::vector<transformer_engine::DType>& input_type_list,
    std::vector<void*> scale_dptr_list,  // i
    const std::vector<std::vector<size_t>>& scale_shape_list,
    const std::vector<transformer_engine::DType>& scale_type_list,
    std::vector<void*> cast_output_dptr_list,  // o
    const std::vector<std::vector<size_t>>& cast_output_shape_list,
    const std::vector<transformer_engine::DType>& cast_output_type_list,
    std::vector<void*> transposed_output_dptr_list,  // o
    const std::vector<std::vector<size_t>>& transposed_output_shape_list,
    const std::vector<transformer_engine::DType>& transposed_output_type_list,
    std::vector<void*> amax_dptr_list,  // o
    const std::vector<std::vector<size_t>>& amax_shape_list,
    const std::vector<transformer_engine::DType>& amax_type_list,
    std::vector<void*> scale_inv_dptr_list,  // o
    const std::vector<std::vector<size_t>>& scale_inv_shape_list,
    const std::vector<transformer_engine::DType>& scale_inv_type_list
) {
  transformer_engine::TensorWrapper workspace;

  // Construct TE tensors
  std::vector<NVTETensor> input_list, scale_list,
      cast_output_list, transposed_output_list, amax_list, scale_inv_list;
  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  auto make_tensor = [&tensor_wrappers](void* dptr,
                                        const std::vector<size_t>& shape,
                                        transformer_engine::DType dtype)
      -> NVTETensor {
    tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
    return tensor_wrappers.back().data();
  };
  for (size_t i = 0; i < input_dptr_list.size(); ++i) {
    input_list.emplace_back(make_tensor(input_dptr_list[i],
                                        input_shape_list[i],
                                        input_type_list[i]));
    scale_list.emplace_back(make_tensor(scale_dptr_list[i],
                                        scale_shape_list[i],
                                        scale_type_list[i]));
    cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
                                              cast_output_shape_list[i],
                                              cast_output_type_list[i]));
    transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
                                                    transposed_output_shape_list[i],
                                                    transposed_output_type_list[i]));
    amax_list.emplace_back(make_tensor(amax_dptr_list[i],
                                       amax_shape_list[i],
                                       amax_type_list[i]));
    scale_inv_list.emplace_back(make_tensor(scale_inv_dptr_list[i],
                                            scale_inv_shape_list[i],
                                            scale_inv_type_list[i]));
  }

  // Check tensor lists
  NVTE_CHECK(scale_list.size() == input_list.size(),
             "Number of input and scale tensors must match");
  NVTE_CHECK(cast_output_list.size() == input_list.size(),
             "Number of input and C output tensors must match");
  NVTE_CHECK(transposed_output_list.size() == input_list.size(),
             "Number of input and T output tensors must match");
  NVTE_CHECK(amax_list.size() == input_list.size(),
             "Number of input and AMAX tensors must match");
  NVTE_CHECK(scale_inv_list.size() == input_list.size(),
             "Number of input and scale_inv tensors must match");

  // Launch TE kernel
  nvte_multi_cast_transpose(input_list.size(),
                            input_list.data(),
                            scale_list.data(),
                            cast_output_list.data(),
                            transposed_output_list.data(),
                            amax_list.data(),
                            scale_inv_list.data(),
                            at::cuda::getCurrentCUDAStream());
}
@@ -269,4 +269,26 @@ void dispatch_bgrad_dgelu_cast_transpose_fusion(
);
void dispatch_multi_cast_transpose(
    std::vector<void*> input_dptr_list,  // i
    const std::vector<std::vector<size_t>>& input_shape_list,
    const std::vector<transformer_engine::DType>& input_type_list,
    std::vector<void*> scale_dptr_list,  // i
    const std::vector<std::vector<size_t>>& scale_shape_list,
    const std::vector<transformer_engine::DType>& scale_type_list,
    std::vector<void*> cast_output_dptr_list,  // o
    const std::vector<std::vector<size_t>>& cast_output_shape_list,
    const std::vector<transformer_engine::DType>& cast_output_type_list,
    std::vector<void*> transposed_output_dptr_list,  // o
    const std::vector<std::vector<size_t>>& transposed_output_shape_list,
    const std::vector<transformer_engine::DType>& transposed_output_type_list,
    std::vector<void*> amax_dptr_list,  // o
    const std::vector<std::vector<size_t>>& amax_shape_list,
    const std::vector<transformer_engine::DType>& amax_type_list,
    std::vector<void*> scale_inv_dptr_list,  // o
    const std::vector<std::vector<size_t>>& scale_inv_shape_list,
    const std::vector<transformer_engine::DType>& scale_inv_type_list
);
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
@@ -173,6 +173,96 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
}
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                std::vector<at::Tensor> scale_list,
                                std::vector<at::Tensor> cast_output_list,
                                std::vector<at::Tensor> transposed_output_list,
                                std::vector<at::Tensor> amax_list,
                                std::vector<at::Tensor> scale_inv_list,
                                transformer_engine::DType otype
) {
  using namespace transformer_engine;

  // Extract properties from PyTorch tensors
  std::vector<void*> input_dptr_list, scale_dptr_list,
      cast_output_dptr_list, transposed_output_dptr_list,
      amax_dptr_list, scale_inv_dptr_list;
  std::vector<std::vector<size_t>> input_shape_list, scale_shape_list,
      cast_output_shape_list, transposed_output_shape_list,
      amax_shape_list, scale_inv_shape_list;
  std::vector<transformer_engine::DType> input_type_list, scale_type_list,
      cast_output_type_list, transposed_output_type_list,
      amax_type_list, scale_inv_type_list;
  auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor,
                                            std::vector<void*>& dptr_list,
                                            std::vector<std::vector<size_t>>& shape_list) {
    dptr_list.push_back(tensor.data_ptr());
    shape_list.push_back({});
    for (int d = 0; d < tensor.dim(); ++d) {
      shape_list.back().push_back(tensor.size(d));
    }
  };
  auto extract_tensor_props = [](at::Tensor& tensor,
                                 std::vector<void*>& dptr_list,
                                 std::vector<std::vector<size_t>>& shape_list,
                                 std::vector<transformer_engine::DType>& type_list) {
    dptr_list.push_back(tensor.data_ptr());
    shape_list.push_back({});
    for (int d = 0; d < tensor.dim(); ++d) {
      shape_list.back().push_back(tensor.size(d));
    }
    type_list.push_back(GetTransformerEngineDType(tensor.scalar_type()));
  };
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    extract_tensor_props(input_list[tensor_id],
                         input_dptr_list,
                         input_shape_list,
                         input_type_list);
    extract_tensor_props(scale_list[tensor_id],
                         scale_dptr_list,
                         scale_shape_list,
                         scale_type_list);
    extract_tensor_props_skip_dtype(cast_output_list[tensor_id],
                                    cast_output_dptr_list,
                                    cast_output_shape_list);
    cast_output_type_list.push_back(otype);
    extract_tensor_props_skip_dtype(transposed_output_list[tensor_id],
                                    transposed_output_dptr_list,
                                    transposed_output_shape_list);
    transposed_output_type_list.push_back(otype);
    extract_tensor_props(amax_list[tensor_id],
                         amax_dptr_list,
                         amax_shape_list,
                         amax_type_list);
    extract_tensor_props(scale_inv_list[tensor_id],
                         scale_inv_dptr_list,
                         scale_inv_shape_list,
                         scale_inv_type_list);
  }

  // Launch TE kernel
  dispatch_multi_cast_transpose(
    input_dptr_list,
    input_shape_list,
    input_type_list,
    scale_dptr_list,
    scale_shape_list,
    scale_type_list,
    cast_output_dptr_list,
    cast_output_shape_list,
    cast_output_type_list,
    transposed_output_dptr_list,
    transposed_output_shape_list,
    transposed_output_type_list,
    amax_dptr_list,
    amax_shape_list,
    amax_type_list,
    scale_inv_dptr_list,
    scale_inv_shape_list,
    scale_inv_type_list);
}
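A hedged sketch of driving the new binding from the C++ side (illustrative only, not part of the commit). It assumes the FP8 output buffers are byte-typed storage, matching how this extension passes raw FP8 data around, and the FP8 enum value shown is an assumption about the DType name.

// Illustrative only: one fp32 CUDA tensor cast + transposed to FP8,
// as a batch of size one.
std::vector<at::Tensor> inputs = {
    at::randn({768, 768}, at::device(at::kCUDA).dtype(at::kFloat))};
std::vector<at::Tensor> scales = {
    at::ones({1}, at::device(at::kCUDA).dtype(at::kFloat))};
std::vector<at::Tensor> casts = {
    at::empty({768, 768}, at::device(at::kCUDA).dtype(at::kByte))};  // FP8 bits
std::vector<at::Tensor> transposes = {
    at::empty({768, 768}, at::device(at::kCUDA).dtype(at::kByte))};
std::vector<at::Tensor> amaxes = {
    at::zeros({1}, at::device(at::kCUDA).dtype(at::kFloat))};
std::vector<at::Tensor> scale_invs = {
    at::ones({1}, at::device(at::kCUDA).dtype(at::kFloat))};
fused_multi_cast_transpose(inputs, scales, casts, transposes,
                           amaxes, scale_invs,
                           transformer_engine::DType::kFloat8E4M3);  // assumed enum name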
at::Tensor fp8_transpose(at::Tensor input,
                         transformer_engine::DType otype
) {
...
@@ -403,6 +493,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Fused Cast + Transpose + BGRAD");
  m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
        "Fused Cast + Transpose + BGRAD + DGELU");
  m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
        "Fused Multi-tensor Cast + Transpose");
  m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
  m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
  m.def("te_gemm", &te_gemm, "CublasLt GEMM");
......
@@ -54,6 +54,16 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
);
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                std::vector<at::Tensor> scale_list,
                                std::vector<at::Tensor> cast_output_list,
                                std::vector<at::Tensor> transposed_output_list,
                                std::vector<at::Tensor> amax_output_list,
                                std::vector<at::Tensor> scale_inv_output_list,
                                transformer_engine::DType otype
);
at::Tensor fp8_transpose(at::Tensor input,
                         transformer_engine::DType otype
);
......