Unverified Commit 23cf4ff9 authored by xiaoxi-wangfj, committed by GitHub

[PyTorch|common] Optimize unpadding kernel for FP8 (#1866)



* [PyTorch|common] Implement unpadding kernel for FP8

1. Add a multi-tensor unpadding kernel
2. Replace split+cat with the unpadding kernel in Fp8Padding and Fp8Unpadding
3. Add unit tests that exercise unpadding together with padding
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add license
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Update padding.cu
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent c30e961f
......@@ -25,6 +25,7 @@ add_executable(test_operator
    test_memset.cu
    test_multi_cast_transpose.cu
    test_multi_padding.cu
    test_multi_unpadding.cu
    test_causal_softmax.cu
    test_swizzle.cu
    ../test_common.cu)
......
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/padding.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_unpadding_ref(const std::vector<std::vector<InputType>>& input_list,
                           std::vector<std::vector<OutputType>>& output_list,
                           const std::vector<size_t>& height_list,
                           const std::vector<size_t>& width_list,
                           const std::vector<int>& padded_height_list) {
  using compute_t = float;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = input_list[tensor_id];
    auto& output = output_list[tensor_id];
    const size_t height = height_list[tensor_id];
    const size_t width = width_list[tensor_id];

    // Only copy the valid (unpadded) portion; rows at or beyond `height` are ignored
    for (size_t i = 0; i < height; ++i) {
      for (size_t j = 0; j < width; ++j) {
        const compute_t x = static_cast<compute_t>(input[i * width + j]);
        const OutputType y = static_cast<OutputType>(x);
        output[i * width + j] = y;
      }
    }
  }
}
template <typename InputType, typename OutputType>
void performUnpaddingTest() {
  using namespace test;

  const DType itype = TypeInfo<InputType>::dtype;
  const DType otype = TypeInfo<OutputType>::dtype;
  const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1, 1},
                                                              {1, 768},
                                                              {768, 1},
                                                              {768, 768},
                                                              {43, 43},
                                                              {43, 256},
                                                              {256, 43},
                                                              {256, 256}};
  const size_t num_tensors = tensor_dims.size();
  constexpr int align = 16;

  // Buffers for Transformer Engine implementation
  std::vector<Tensor> padded_input_list, unpadded_output_list;

  // Buffers for reference implementation
  std::vector<std::vector<InputType>> ref_padded_input_list;
  std::vector<std::vector<OutputType>> ref_unpadded_output_list;
  std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
  std::vector<int> ref_padded_height_list(num_tensors);

  // Initialize buffers
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    const size_t original_height = tensor_dims[tensor_id].first;
    const size_t width = tensor_dims[tensor_id].second;
    // Round the height up to the next multiple of align
    const size_t padded_height = (original_height + align - 1) / align * align;

    // Input is the padded tensor (padded_height x width)
    padded_input_list.emplace_back(Tensor("padded_input_" + std::to_string(tensor_id),
                                          std::vector<size_t>{padded_height, width}, itype));
    // Output is the unpadded tensor (original_height x width)
    unpadded_output_list.emplace_back(Tensor("unpadded_output_" + std::to_string(tensor_id),
                                             std::vector<size_t>{original_height, width}, otype));

    auto& padded_input = padded_input_list.back();
    auto& unpadded_output = unpadded_output_list.back();

    // Fill padded input with random data (including the padding area)
    fillUniform(&padded_input);
    setRandomScale(&unpadded_output);

    // Initialize reference buffers
    ref_padded_input_list.emplace_back(padded_height * width);
    ref_unpadded_output_list.emplace_back(original_height * width);

    // Copy data to reference buffers
    std::copy(padded_input.rowwise_cpu_dptr<InputType>(),
              padded_input.rowwise_cpu_dptr<InputType>() + padded_height * width,
              ref_padded_input_list.back().begin());
    ref_height_list[tensor_id] = original_height;
    ref_width_list[tensor_id] = width;
    ref_padded_height_list[tensor_id] = padded_height;
  }

  // Transformer Engine implementation
  auto make_nvte_vector = [](std::vector<Tensor>& tensor_list) -> std::vector<NVTETensor> {
    std::vector<NVTETensor> nvte_tensor_list;
    for (auto& tensor : tensor_list) {
      nvte_tensor_list.emplace_back(tensor.data());
    }
    return nvte_tensor_list;
  };

  // Convert height_list to int for the API
  std::vector<int> original_height_list_int(num_tensors);
  for (size_t i = 0; i < num_tensors; ++i) {
    original_height_list_int[i] = static_cast<int>(ref_height_list[i]);
  }

  // Call unpadding API
  nvte_multi_unpadding(num_tensors, make_nvte_vector(padded_input_list).data(),
                       make_nvte_vector(unpadded_output_list).data(),
                       original_height_list_int.data(), 0);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // Reference implementation
  compute_unpadding_ref<InputType, OutputType>(ref_padded_input_list, ref_unpadded_output_list,
                                               ref_height_list, ref_width_list,
                                               ref_padded_height_list);

  // Check correctness
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    auto [atol, rtol] = getTolerances(otype);
    compareResults("unpadded_output", unpadded_output_list[tensor_id],
                   ref_unpadded_output_list[tensor_id].data(), true, atol, rtol);
  }
}
} // namespace
class MultiUnpaddingTestSuite : public ::testing::TestWithParam<transformer_engine::DType> {};

TEST_P(MultiUnpaddingTestSuite, TestMultiUnpadding) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = GetParam();
  const DType output_type = input_type;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performUnpaddingTest<InputType, OutputType>();
    );
  );
}

INSTANTIATE_TEST_SUITE_P(
    OperatorTest, MultiUnpaddingTestSuite, ::testing::ValuesIn(test::all_fp_types),
    [](const testing::TestParamInfo<MultiUnpaddingTestSuite::ParamType>& info) {
      std::string name = test::typeName(info.param);
      return name;
    });
......@@ -44,6 +44,33 @@ extern "C" {
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
                        const int* padded_num_rows_list, cudaStream_t stream);

/*! \brief Unpad multiple tensors (the reverse of padding).
 *
 *  NOTE: Unpadding only removes rows from the bottom of a tensor.
 *
 *  For example, a 4x3 matrix is unpadded to a 3x3 matrix:
 *
 *  source
 *  | 1 | 2 | 3 |
 *  | 4 | 5 | 6 |
 *  | 7 | 8 | 9 |
 *  | 0 | 0 | 0 |
 *
 *  destination
 *  | 1 | 2 | 3 |
 *  | 4 | 5 | 6 |
 *  | 7 | 8 | 9 |
 *
 *  \param[in]     num_tensors             Number of tensors.
 *  \param[in]     input_list              List of 2D padded input tensors.
 *  \param[in,out] output_list             List of unpadded output tensors. Dimensions
 *                                         match the original, unpadded tensors.
 *  \param[in]     unpadded_num_rows_list  List of unpadded row counts, one per input tensor.
 *  \param[in]     stream                  CUDA stream used for the operation.
 */
void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list,
                          NVTETensor* output_list, const int* unpadded_num_rows_list,
                          cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
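For orientation, here is a minimal sketch of driving this new C API directly. It is not part of the commit: it borrows makeTransformerEngineTensor and TensorWrapper from the PyTorch binding shown later in this diff, and the device pointers, shapes, and FP8 dtype are illustrative assumptions.

#include <vector>
#include <cuda_runtime.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/transformer_engine.h>  // TensorWrapper
// Assumption: makeTransformerEngineTensor from the PyTorch binding is visible here.

void unpad_two_fp8_tensors(void* d_in0, void* d_in1, void* d_out0, void* d_out1,
                           cudaStream_t stream) {
  using transformer_engine::DType;
  using transformer_engine::TensorWrapper;
  // Wrappers own the NVTETensor handles; keep them alive across the call.
  std::vector<TensorWrapper> wrappers;
  wrappers.reserve(4);
  std::vector<NVTETensor> inputs, outputs;

  // Padded inputs: 48x256 and 256x256 (rows rounded up to a multiple of 16).
  // Unpadded outputs keep the original 43x256 and 256x256 shapes.
  auto add = [&](void* dptr, std::vector<size_t> shape, std::vector<NVTETensor>& list) {
    wrappers.emplace_back(transformer_engine::pytorch::makeTransformerEngineTensor(
        dptr, shape, DType::kFloat8E4M3));
    list.push_back(wrappers.back().data());
  };
  add(d_in0, {48, 256}, inputs);
  add(d_in1, {256, 256}, inputs);
  add(d_out0, {43, 256}, outputs);
  add(d_out1, {256, 256}, outputs);

  std::vector<int> unpadded_rows = {43, 256};
  // One call strips the bottom padding rows from every tensor in the list.
  nvte_multi_unpadding(inputs.size(), inputs.data(), outputs.data(),
                       unpadded_rows.data(), stream);
}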
......@@ -126,6 +126,83 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
}
}
template <int nvec, typename Type>
__global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(MultiPaddingArgs args) {
  using Vec = Vec<Type, nvec>;

  // Thread indices
  // Note: Block is interpreted as a warp_size x num_warps grid
  constexpr int bdimx = THREADS_PER_WARP;
  constexpr int bdimy = n_warps_per_tile;
  const int tid = threadIdx.x;
  const int tidx = tid % bdimx;
  const int tidy = tid / bdimx;
  const int bid = blockIdx.x;

  // Input tensors are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
  constexpr int tile_dim_n = THREADS_PER_WARP * nvec;

  // Number of nvec x nvec subtiles for each thread to load/store
  constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

  // Find the tensor corresponding to this block
  int tensor_id = 0;
  while (args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
  Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
  const int num_rows = args.num_rows_list[tensor_id];
  const int row_length = args.row_length_list[tensor_id];

  // Find the position of the tile within the tensor
  const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
  const int tile_id = bid - args.block_range[tensor_id];
  const int tile_id_m = tile_id / num_tiles_n;
  const int tile_id_n = tile_id % num_tiles_n;
  const int tile_row = tile_id_m * tile_dim_m;
  const int tile_col = tile_id_n * tile_dim_n;

  // Copy valid rows from input to output through registers
  // Note: Each thread handles n_iterations subtiles. Rows at or beyond
  // num_rows (the padding rows) are simply skipped.
#pragma unroll
  for (int iter = 0; iter < n_iterations; ++iter) {
    const int i1 = tidy + iter * bdimy;
    const int j1 = tidx;
#pragma unroll
    for (int i2 = 0; i2 < nvec; ++i2) {
      const int row = tile_row + i1 * nvec + i2;
      const int col = tile_col + j1 * nvec;
      Vec local_input;
      Vec local_output;
      local_input.clear();
      if (row < num_rows) {
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            local_input.data.elt[j2] = input[row * row_length + col + j2];
          }
        }
      }
#pragma unroll
      for (int j2 = 0; j2 < nvec; ++j2) {
        local_output.data.elt[j2] = local_input.data.elt[j2];
      }
      if (row < num_rows) {
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            output[row * row_length + col + j2] = local_output.data.elt[j2];
          }
        }
      }
    }
  }
}
} // namespace
void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
......@@ -202,6 +279,78 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
}
}
void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
                     const std::vector<int> unpadded_num_rows_list, cudaStream_t stream) {
  // Check that number of tensors is valid
  NVTE_CHECK(output_list.size() == input_list.size(),
             "Number of input and output tensors must match");
  if (input_list.empty()) {
    return;
  }

  // Check that tensor properties are valid
  DType type = input_list[0]->data.dtype;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = *input_list[tensor_id];
    const auto& output = *output_list[tensor_id];
    CheckInputTensor(input, "multi_unpadding_input_" + std::to_string(tensor_id));
    CheckInputTensor(output, "multi_unpadding_output_" + std::to_string(tensor_id));

    NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
    NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");

    NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
    NVTE_CHECK(output.data.shape[0] == unpadded_num_rows_list[tensor_id],
               "Output tensor shape does not match the requested number of unpadded rows.");
  }

  // Input matrices are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
  const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);

  // Add tensors to kernel argument struct
  MultiPaddingArgs kernel_args;
  kernel_args.num_tensors = 0;
  kernel_args.block_range[0] = 0;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    // Launch kernel if argument struct is full
    if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
          const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
          multi_unpadding_kernel<nvec, Type>
          <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args););  // NOLINT(*)
      kernel_args.num_tensors = 0;
    }

    // Calculate number of thread blocks needed for tensor
    const int num_rows = unpadded_num_rows_list[tensor_id];
    const int row_length = input_list[tensor_id]->data.shape[1];
    const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
    const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
    const int num_tiles = num_tiles_m * num_tiles_n;

    // Add tensor to kernel argument struct
    const int pos = kernel_args.num_tensors;
    kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
    kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
    kernel_args.num_rows_list[pos] = num_rows;
    kernel_args.row_length_list[pos] = row_length;
    kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
    kernel_args.num_tensors++;
  }

  // Launch kernel with any remaining tensors
  if (kernel_args.num_tensors > 0) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
        type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
        const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
        multi_unpadding_kernel<nvec, Type>
        <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args););  // NOLINT(*)
  }
}
} // namespace transformer_engine
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
......@@ -217,3 +366,17 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
  }
  multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
}
void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list,
                          NVTETensor* output_list, const int* unpadded_num_rows_list,
                          cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_unpadding);
  using namespace transformer_engine;
  std::vector<Tensor*> input_list_, output_list_;
  std::vector<int> unpadded_num_rows_list_;
  for (size_t i = 0; i < num_tensors; ++i) {
    input_list_.push_back(convertNVTETensorCheck(input_list[i]));
    output_list_.push_back(convertNVTETensorCheck(output_list[i]));
    unpadded_num_rows_list_.push_back(unpadded_num_rows_list[i]);
  }
  multi_unpadding(input_list_, output_list_, unpadded_num_rows_list_, stream);
}
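The host code above sizes its grid from square tiles of THREADS_PER_WARP * nvec elements per side, with nvec chosen so that each access moves desired_load_store_size bytes. A standalone sketch of that arithmetic follows; the 8-byte load/store size is an assumption for illustration (the actual constant lives in padding.cu), the rest is the same ceiling-division math as above.

#include <cstdio>

int main() {
  const int threads_per_warp = 32;  // THREADS_PER_WARP
  const int load_store_bytes = 8;   // assumed value of desired_load_store_size
  const int elem_size = 1;          // FP8 element is 1 byte
  const int nvec = load_store_bytes / elem_size;  // elements per vectorized access
  const int tile_dim = threads_per_warp * nvec;   // 32 * 8 = 256 elements per tile side

  // Example: a 300x1024 padded tensor unpadded to 290 valid rows.
  const int num_rows = 290, row_length = 1024;
  const int num_tiles_m = (num_rows + tile_dim - 1) / tile_dim;    // ceil(290/256) = 2
  const int num_tiles_n = (row_length + tile_dim - 1) / tile_dim;  // ceil(1024/256) = 4
  std::printf("thread blocks for this tensor: %d\n", num_tiles_m * num_tiles_n);  // 8
  return 0;
}

Because block counts are accumulated per tensor in block_range, one kernel launch can cover up to kMaxTensorsPerKernel tensors; each block finds its tensor by scanning block_range, exactly as the kernel does.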
......@@ -368,6 +368,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);

void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
                               std::vector<size_t> input_row_list,
                               std::vector<size_t> unpadded_input_row_list);
/***************************************************************************************************
 * NVSHMEM APIs
 **************************************************************************************************/
......
......@@ -81,4 +81,77 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
});
}
void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
                               std::vector<size_t> input_row_list,
                               std::vector<size_t> unpadded_input_row_list) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
  NVTE_CHECK(input_row_list.size() == unpadded_input_row_list.size(),
             "Lengths of input row list and unpadded row list must match.");
  NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
  NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");

  const auto num_tensors = input_row_list.size();

  // Extract properties from PyTorch tensors
  std::vector<void*> input_dptr_list, output_dptr_list;
  std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
  std::vector<transformer_engine::DType> input_type_list;
  void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
  void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    input_dptr_list.push_back(d_input_ptr);
    output_dptr_list.push_back(d_output_ptr);

    // Move the input pointer past the current (padded) split.
    char* input_char_ptr = reinterpret_cast<char*>(d_input_ptr);
    const size_t input_dptr_offset =
        input_row_list[tensor_id] * input.size(1) * input.element_size();
    input_char_ptr += input_dptr_offset;
    d_input_ptr = reinterpret_cast<void*>(input_char_ptr);

    input_shape_list.push_back({input_row_list[tensor_id], static_cast<size_t>(input.size(1))});
    input_type_list.push_back(GetTransformerEngineDType(input.scalar_type()));

    // Move the output pointer past the current (unpadded) split.
    char* output_char_ptr = reinterpret_cast<char*>(d_output_ptr);
    const size_t output_dptr_offset =
        unpadded_input_row_list[tensor_id] * output.size(1) * output.element_size();
    output_char_ptr += output_dptr_offset;
    d_output_ptr = reinterpret_cast<void*>(output_char_ptr);

    output_shape_list.push_back(
        {unpadded_input_row_list[tensor_id], static_cast<size_t>(output.size(1))});
  }

  // Construct TE tensors
  std::vector<NVTETensor> nvte_input_list, nvte_output_list;
  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
                                        transformer_engine::DType dtype) -> NVTETensor {
    tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
    return tensor_wrappers.back().data();
  };

  std::vector<int> unpadded_num_rows_list;
  for (size_t i = 0; i < input_dptr_list.size(); ++i) {
    if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue;
    nvte_input_list.emplace_back(
        make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i]));
    nvte_output_list.emplace_back(
        make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i]));
    unpadded_num_rows_list.emplace_back(unpadded_input_row_list[i]);
  }

  // Check tensor lists
  NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(),
             "Number of input and output tensors must match");
  NVTE_CHECK(unpadded_num_rows_list.size() == nvte_input_list.size(),
             "Lengths of input list and unpadded row list must match");

  // Launch TE kernel
  nvte_multi_unpadding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
                       unpadded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
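A minimal libtorch sketch of calling fused_multi_row_unpadding directly, mirroring what the Python autograd functions below do. The split sizes, shapes, and dtype are made-up example values, and it assumes the TE PyTorch extension header declaring the function is on the include path.

#include <torch/torch.h>
// Assumption: the extensions.h declaring fused_multi_row_unpadding is available.

void example_unpad() {
  // Two logical splits, padded to 48 and 256 rows (align = 16),
  // concatenated into one 2D buffer of width 128.
  std::vector<size_t> padded_rows = {48, 256};
  std::vector<size_t> unpadded_rows = {43, 250};
  auto input = torch::randn({48 + 256, 128}, torch::kCUDA);
  auto output = torch::empty({43 + 250, 128}, input.options());

  // Strips the bottom padding rows of each split in a single fused kernel,
  // instead of a torch.split followed by torch.cat.
  transformer_engine::pytorch::fused_multi_row_unpadding(input, output, padded_rows,
                                                         unpadded_rows);
}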
......@@ -232,6 +232,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
"Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
"Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
// attention kernels
m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
......
......@@ -53,15 +53,16 @@ class _Fp8Padding(torch.autograd.Function):
         if ctx.requires_dgrad:
             grad_output = grad_output.contiguous()
-            grad_output_mats = torch.split(
-                grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits
-            )
-            grad_input = torch.cat(
-                [
-                    grad_output_mat[: ctx.m_splits[i]]
-                    for i, grad_output_mat in enumerate(grad_output_mats)
-                ],
-                dim=0,
-            )
+            in_features = grad_output.shape[-1]
+            # Allocate the output tensor for the unpadded gradient
+            total_row = sum(ctx.m_splits)
+            grad_input = torch.empty(
+                [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device
+            )
+            tex.fused_multi_row_unpadding(
+                grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits
+            )

         return (grad_input, None, None, None)
......
......@@ -29,10 +29,13 @@ class _Fp8Unpadding(torch.autograd.Function):
         is_grad_enabled: bool,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
-        inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
-        out_ret = torch.cat(
-            [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0
-        )
+        in_features = inp.shape[-1]
+        # Allocate the output tensor with the total number of unpadded rows
+        total_row = sum(m_splits)
+        out_ret = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device)
+        tex.fused_multi_row_unpadding(inp.view(-1, in_features), out_ret, padded_m_splits, m_splits)

         if is_grad_enabled:
             ctx.m_splits = m_splits
......