Unverified commit 215db88d authored by Ruibin Cheung, committed by GitHub

[PyTorch] Implement Fp8 padding and unpadding module (#1129)



* [TE/PyTorch][MoE] Add FP8 padding and unpadding modules

 1. Add a multi-tensor padding kernel for FP8 with a padding alignment of 16.
 2. Add Fp8Padding and Fp8Unpadding modules (see the usage sketch below).
 3. Add padded GroupedLinear unit tests.
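A minimal usage sketch of the new modules, based on the unit tests in this PR
(shapes and split sizes are illustrative):

    import torch
    from transformer_engine.pytorch import (
        GroupedLinear, Fp8Padding, Fp8Unpadding, fp8_autocast,
    )

    num_gemms = 4
    pad = Fp8Padding(num_gemms)
    unpad = Fp8Unpadding(num_gemms)
    linear = GroupedLinear(num_gemms, 1024, 4096, bias=False, device="cuda")

    inp = torch.randn(1000, 1024, device="cuda", requires_grad=True)
    m_splits = [300, 250, 250, 200]  # rows per GEMM; not multiples of 16

    with fp8_autocast(enabled=True):
        padded_inp, padded_m_splits = pad(inp, m_splits)  # rows rounded up to [304, 256, 256, 208]
        out = linear(padded_inp, padded_m_splits)
        out = unpad(out, m_splits)  # back to sum(m_splits) = 1000 rows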

---------
Signed-off-by: beinggod <zhangruibin@01.ai>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 454e3895
......@@ -13,6 +13,7 @@ add_executable(test_operator
test_layernorm.cu
test_rmsnorm.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
../test_common.cu)
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/padding.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const std::vector<std::vector<InputType>>& input_list,
std::vector<std::vector<OutputType>>& output_list,
const std::vector<size_t>& height_list,
const std::vector<size_t>& width_list,
const std::vector<int>& padded_height_list) {
using compute_t = float;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = input_list[tensor_id];
auto& output = output_list[tensor_id];
const size_t height = height_list[tensor_id];
const size_t width = width_list[tensor_id];
const size_t padded_height = padded_height_list[tensor_id];
for (size_t i = 0; i < padded_height; ++i) {
if (i < height) {
for (size_t j = 0; j < width; ++j) {
const compute_t x = static_cast<compute_t>(input[i * width + j]);
const OutputType y = static_cast<OutputType>(x);
output[i * width + j] = y;
}
} else {
for (size_t j = 0; j < width; ++j) {
output[i * width + j] = static_cast<OutputType>(0.f);
}
}
}
}
}
template <typename InputType, typename OutputType>
void performTest() {
using namespace test;
const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768},
{768,1},
{768,768},
{43,43},
{43,256},
{256,43},
{256,256}};
const size_t num_tensors = tensor_dims.size();
constexpr int align = 16;
// Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_list;
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
std::vector<std::vector<OutputType>> ref_output_list;
std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
std::vector<int> ref_padded_height_list(num_tensors);
// Initialize buffers
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
const size_t height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
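// Round height up to the next multiple of align, e.g. 43 -> 48 for align = 16.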
const size_t padded_height = (height + align - 1) / align * align;
input_list.emplace_back(Tensor({ height, width }, itype));
output_list.emplace_back(Tensor({ padded_height, width }, otype));
auto& input = input_list.back();
auto& output = output_list.back();
fillUniform(&input);
setRandomScale(&output);
ref_input_list.emplace_back(height*width);
ref_output_list.emplace_back(padded_height*width);
std::copy(input.cpu_dptr<InputType>(),
input.cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
ref_padded_height_list[tensor_id] = padded_height;
}
// Transformer Engine implementation
auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
-> std::vector<NVTETensor> {
std::vector<NVTETensor> nvte_tensor_list;
for (auto& tensor : tensor_list) {
nvte_tensor_list.emplace_back(tensor.data());
}
return nvte_tensor_list;
};
nvte_multi_padding(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(output_list).data(),
ref_padded_height_list.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
// Reference implementation
compute_ref<InputType, OutputType>(ref_input_list,
ref_output_list,
ref_height_list,
ref_width_list,
ref_padded_height_list);
// Check correctness
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
auto [atol, rtol] = getTolerances(otype);
compareResults("output",
output_list[tensor_id],
ref_output_list[tensor_id].data(),
atol, rtol);
}
}
} // namespace
class MultiPaddingTestSuite
: public ::testing::TestWithParam<
transformer_engine::DType> {};
TEST_P(MultiPaddingTestSuite, TestMultiPadding) {
using namespace transformer_engine;
using namespace test;
const DType input_type = GetParam();
const DType output_type = input_type;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>();
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MultiPaddingTestSuite,
::testing::ValuesIn(test::all_fp_types),
[](const testing::TestParamInfo<MultiPaddingTestSuite::ParamType>& info) {
std::string name = test::typeName(info.param);
return name;
});
......@@ -7,6 +7,7 @@ import os
from typing import Dict, List, Optional
import pytest
import copy
import random
import torch
import torch.nn as nn
......@@ -30,6 +31,8 @@ from transformer_engine.pytorch import (
TransformerLayer,
LayerNorm,
InferenceParams,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
......@@ -354,6 +357,40 @@ class TorchSquaredRELU(nn.Module):
return (input > 0) * input * input
class TorchGroupedLinearWithPadding(nn.Module):
def __init__(
self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
) -> None:
super().__init__()
self.padding = Fp8Padding(num_gemms)
self.linear_fn = GroupedLinear(
num_gemms,
in_features,
out_features,
bias=bias,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
device="cuda",
)
self.unpadding = Fp8Unpadding(num_gemms)
self.fp8 = fp8
def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
if self.fp8:
orig_m_splits = m_splits
inp, m_splits = self.padding(inp, m_splits)
out = self.linear_fn(inp, m_splits)
if self.fp8:
out = self.unpadding(out, orig_m_splits)
return out
_supported_act = {
"geglu": nn.GELU(approximate="tanh"),
"gelu": nn.GELU(approximate="tanh"),
......@@ -1328,6 +1365,158 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
"""Padding tensor shapes to multiples of 16."""
padded_tokens_per_expert = [
(num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
]
hidden_states = torch.split(hidden_states, tokens_per_expert)
padded_hidden_states = []
for hidden_state, actual_num_tokens, padded_num_tokens in zip(
hidden_states, tokens_per_expert, padded_tokens_per_expert
):
padded_hidden_states.append(hidden_state)
if padded_num_tokens > actual_num_tokens:
pad_tensor = torch.zeros(
padded_num_tokens - actual_num_tokens,
hidden_state.shape[1],
dtype=hidden_state.dtype,
device=hidden_state.device,
)
padded_hidden_states.append(pad_tensor)
padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
return padded_hidden_states, padded_tokens_per_expert
def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
inputmats = torch.split(
padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
)
hidden_states = torch.cat(
[
inputmat[: actual_tokens_per_expert[i]]
for i, inputmat in enumerate(inputmats)
],
dim=0,
)
return hidden_states
def _generate_random_numbers(n, total_sum):
if n <= 0:
return []
if n == 1:
return [total_sum]
# reset seed
random.seed(seed)
breaks = sorted(random.sample(range(1, total_sum), n - 1))
random_numbers = (
[breaks[0]]
+ [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]
+ [total_sum - breaks[-1]]
)
return random_numbers
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len * bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
with fp8_autocast(enabled=fp8):
if isinstance(block, TorchGroupedLinearWithPadding):
out = block(inp_hidden_states, m_splits)
else:
if fp8:
padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
inp_hidden_states, m_splits
)
padded_out = block(padded_inp_hidden_states, padding_m_splits)
out = _unpad_tensor_for_fp8(padded_out, m_splits, padding_m_splits)
else:
out = block(inp_hidden_states, m_splits)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
).eval()
# Share params
with torch.no_grad():
inner_grouped_linear = grouped_linear.linear_fn
for i in range(num_gemms):
setattr(
ref_grouped_linear,
f"weight{i}",
Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
)
outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, fp8
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, fp8
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states()
......
......@@ -71,6 +71,7 @@ list(APPEND transformer_engine_SOURCES
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_runtime.cpp
util/rtc.cpp
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file padding.h
* \brief Functions handling padding.
*/
#ifndef TRANSFORMER_ENGINE_PADDING_H_
#define TRANSFORMER_ENGINE_PADDING_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Pad multiple tensors.
 *
 * NOTE: Only bottom padding is supported.
 *
 * For example, a 3x3 matrix is padded to a 4x3 matrix:
*
* source
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
*
* destination
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
* | 0 | 0 | 0 |
*
* \param[in] num_tensors Number of tensors.
* \param[in] input_list List of 2D input tensors.
 * \param[in,out] output_list          List of padded output tensors. Each has the
 *                                     corresponding padded number of rows and the
 *                                     same width as its input tensor.
 * \param[in]     padded_num_rows_list List of padded row counts, one per input tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
const int* padded_num_rows_list, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_PADDING_H_
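The padding semantics above can be captured in a small host-side reference
(a sketch for illustration only, not part of this commit):

    import torch

    def multi_padding_ref(input_list, padded_num_rows_list):
        """Reference for nvte_multi_padding: bottom-pad each 2D tensor with zeros."""
        outputs = []
        for inp, padded_rows in zip(input_list, padded_num_rows_list):
            rows, cols = inp.shape
            out = torch.zeros(padded_rows, cols, dtype=inp.dtype, device=inp.device)
            out[:rows] = inp  # copy the payload; the remaining rows stay zero
            outputs.append(out)
        return outputs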
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/padding.h>
#include <cfloat>
#include <iostream>
#include <vector>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace {
// Parameters to tune
constexpr int n_warps_per_tile = 4;
constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
constexpr int desired_load_store_size = 8;
constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
struct MultiPaddingArgs {
// (input) Data buffers for input tensors
void* input_list[kMaxTensorsPerKernel];
// (output) Data buffers for padded output tensors
void* output_list[kMaxTensorsPerKernel];
// Input matrix heights
int num_rows_list[kMaxTensorsPerKernel];
// Input matrix heights (padded)
int padded_num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths
int row_length_list[kMaxTensorsPerKernel];
// Prefix sum of per-tensor tile counts, mapping thread blocks to tensors
int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel
int num_tensors;
};
template <int nvec, typename Type>
__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) {
using Vec = Vec<Type, nvec>;
// Thread indices
// Note: Block is interpreted as a warp_size x num_warps grid
constexpr int bdimx = THREADS_PER_WARP;
constexpr int bdimy = n_warps_per_tile;
const int tid = threadIdx.x;
const int tidx = tid % bdimx;
const int tidy = tid / bdimx;
const int bid = blockIdx.x;
// Input tensors are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
constexpr int tile_dim_n = THREADS_PER_WARP * nvec;
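// e.g. for 1-byte FP8 types, nvec = 8, so each tile covers 256 x 256 elements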
// Number of nvec x nvec subtiles for each thread to
// load/store
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
// Find tensor corresponding to block
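// block_range holds a prefix sum of per-tensor tile counts, so this block
// belongs to the first tensor whose range end exceeds the block index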
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int padded_num_rows = args.padded_num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
const int tile_id = bid - args.block_range[tensor_id];
const int tile_id_m = tile_id / num_tiles_n;
const int tile_id_n = tile_id % num_tiles_n;
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;
// Load input and write output, zero-filling rows in the padded region
// Note: Each thread processes n_iterations subtiles; padding rows
// (num_rows <= row < padded_num_rows) are filled with zeros.
Type local_zero = static_cast<Type>(0.f);
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
const int col = tile_col + j1 * nvec;
Vec local_input;
Vec local_output;
local_input.clear();
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2];
}
}
}
#pragma unroll
for (int j2 = 0; j2 < nvec; ++j2) {
local_output.data.elt[j2] = local_input.data.elt[j2];
}
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_output.data.elt[j2];
}
}
} else if (row < padded_num_rows) {
// padding
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_zero;
}
}
}
}
}
}
} // namespace
void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
const std::vector<int> padded_num_rows_list, cudaStream_t stream) {
// Check that number of tensors is valid
NVTE_CHECK(output_list.size() == input_list.size(),
"Number of input and output tensors must match");
if (input_list.empty()) {
return;
}
// Check that tensor properties are valid
DType type = input_list[0]->data.dtype;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = *input_list[tensor_id];
const auto& output = *output_list[tensor_id];
CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id));
CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id));
NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");
NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id],
"output tensor shape does not match padded input shape.");
}
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
// Add tensors to kernel argument struct
MultiPaddingArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
// Launch kernel if argument struct is full
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
kernel_args.num_tensors = 0;
}
// Calculate number of thread blocks needed for tensor
const int num_rows = input_list[tensor_id]->data.shape[0];
const int padded_num_rows = padded_num_rows_list[tensor_id];
const int row_length = input_list[tensor_id]->data.shape[1];
const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m;
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
const int num_tiles = num_tiles_m * num_tiles_n;
// Add tensor to kernel argument struct
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.padded_num_rows_list[pos] = padded_num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
kernel_args.num_tensors++;
}
// Launch kernel
if (kernel_args.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
}
}
} // namespace transformer_engine
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
const int* padded_num_rows_list, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_padding);
using namespace transformer_engine;
std::vector<Tensor*> input_list_, output_list_;
std::vector<int> padded_num_rows_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i]));
padded_num_rows_list_.push_back(padded_num_rows_list[i]);
}
multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
}
......@@ -67,6 +67,7 @@ from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
......
......@@ -11,3 +11,4 @@ from .transpose import *
from .activation import *
from .normalization import *
from .cast import *
from .padding import *
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for transpose extensions"""
from typing import List, Tuple, Union
import torch
import transformer_engine_torch as tex
__all__ = [
"multi_padding_fused",
]
def multi_padding_fused(
inp: torch.Tensor,
row_list: List[int],
padded_row_list: List[int],
out: torch.Tensor,
) -> None:
"""Bottom-pad each row split of inp to its padded row count, writing the result into out."""
tex.fused_multi_row_padding(
inp,
out,
row_list,
padded_row_list,
)
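A usage sketch for the wrapper above; the output buffer must be pre-allocated
with sum(padded_row_list) rows (sizes here are illustrative):

    import torch

    inp = torch.randn(43 + 60, 256, device="cuda")  # two row splits stacked vertically
    row_list, padded_row_list = [43, 60], [48, 64]  # row counts rounded up to multiples of 16
    out = torch.empty(sum(padded_row_list), 256, dtype=inp.dtype, device=inp.device)
    multi_padding_fused(inp, row_list, padded_row_list, out)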
......@@ -28,6 +28,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
......
......@@ -486,4 +486,12 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
/***************************************************************************************************
* padding
**************************************************************************************************/
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list) {
using namespace transformer_engine;
NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
"Sizes of input row list and padded row list must match.");
NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
const size_t num_tensors = input_row_list.size();
// Extract properties from PyTorch tensors
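// The flat input holds num_tensors 2D splits stacked row-wise; for each split,
// record a pointer into the buffer plus its (rows, width) shape.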
std::vector<void*> input_dptr_list, output_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
std::vector<transformer_engine::DType> input_type_list;
void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
input_dptr_list.push_back(d_input_ptr);
output_dptr_list.push_back(d_output_ptr);
// Move the input pointer to the next split.
char* input_char_ptr = reinterpret_cast<char*>(d_input_ptr);
const size_t input_dptr_offset =
input_row_list[tensor_id] * input.size(1) * input.element_size();
input_char_ptr += input_dptr_offset;
d_input_ptr = reinterpret_cast<void*>(input_char_ptr);
input_shape_list.push_back({input_row_list[tensor_id], static_cast<size_t>(input.size(1))});
input_type_list.push_back(GetTransformerEngineDType(input.scalar_type()));
// Move the output pointer to the next split.
char* output_char_ptr = reinterpret_cast<char*>(d_output_ptr);
const size_t output_dptr_offset =
padded_input_row_list[tensor_id] * output.size(1) * output.element_size();
output_char_ptr += output_dptr_offset;
d_output_ptr = reinterpret_cast<void*>(output_char_ptr);
output_shape_list.push_back(
{padded_input_row_list[tensor_id], static_cast<size_t>(output.size(1))});
}
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list, nvte_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype) -> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
return tensor_wrappers.back().data();
};
std::vector<int> padded_num_rows_list;
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue;
nvte_input_list.emplace_back(
make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i]));
nvte_output_list.emplace_back(
make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i]));
padded_num_rows_list.emplace_back(padded_input_row_list[i]);
}
// Check tensor lists
NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(),
"Number of input and output tensors must match");
NVTE_CHECK(padded_num_rows_list.size() == nvte_input_list.size(),
"Number of input tensors and padded row counts must match");
// Launch TE kernel
nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
}
......@@ -152,7 +152,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
py::call_guard<py::gil_scoped_release>());
......
......@@ -9,4 +9,6 @@ from .grouped_linear import GroupedLinear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
from .fp8_padding import Fp8Padding
from .fp8_unpadding import Fp8Unpadding
from .base import initialize_ub, destroy_ub
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 Padding API"""
from typing import List, Tuple
import torch
from ..cpp_extensions import (
multi_padding_fused,
)
from ..jit import no_torch_dynamo
__all__ = ["Fp8Padding"]
class _Fp8Padding(torch.autograd.Function):
"""functional FP8 padding"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
# The last dimension is the feature size; leading dimensions are flattened below
in_features = inp.shape[-1]
# Allocate padded output tensor
total_row = sum(padded_m_splits)
out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device)
multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out)
if is_grad_enabled:
ctx.m_splits = m_splits
ctx.padded_m_splits = padded_m_splits
ctx.requires_dgrad = inp.requires_grad
return out
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
grad_input = None
if ctx.requires_dgrad:
grad_output = grad_output.contiguous()
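# Unpad the gradient: split by the padded sizes, then keep only the
# first m_splits[i] rows of each chunk.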
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits
)
grad_input = torch.cat(
[
grad_output_mat[: ctx.m_splits[i]]
for i, grad_output_mat in enumerate(grad_output_mats)
],
dim=0,
)
return (grad_input, None, None, None)
class Fp8Padding(torch.nn.Module):
"""
Apply padding to Grouped GEMM input.
Parameters
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
"""
def __init__(
self,
num_gemms,
) -> None:
super().__init__()
self.num_gemms = num_gemms
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
) -> Tuple[torch.Tensor, List[int]]:
"""
Apply the padding to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
m_splits : List[int]
List of integers giving the number of rows in each split of the input tensor.

Returns
-------
The padded tensor and the list of padded split sizes.
"""
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# Round each split up to a multiple of 16 for FP8
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
if torch.is_grad_enabled():
fn = _Fp8Padding.apply
args = []
else:
fn = _Fp8Padding.forward
args = [None]
args += (
inp,
m_splits,
padded_m_splits,
torch.is_grad_enabled(),
)
out = fn(*args)
return out, padded_m_splits
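A small standalone sketch of Fp8Padding (shapes illustrative):

    import torch
    from transformer_engine.pytorch import Fp8Padding

    pad = Fp8Padding(num_gemms=2)
    inp = torch.randn(43 + 60, 512, device="cuda", requires_grad=True)
    out, padded_m_splits = pad(inp, [43, 60])  # padded_m_splits == [48, 64]
    assert out.shape[0] == sum(padded_m_splits)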
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 Padding API"""
from typing import List
import torch
from ..cpp_extensions import (
multi_padding_fused,
)
from ..jit import no_torch_dynamo
__all__ = ["Fp8Unpadding"]
class _Fp8Unpadding(torch.autograd.Function):
"""functional FP8 unpadding"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
out_ret = torch.cat(
[inputmat[: m_splits[i]] for i, inputmat in enumerate(inputmats)], dim=0
)
if is_grad_enabled:
ctx.m_splits = m_splits
ctx.padded_m_splits = padded_m_splits
ctx.requires_dgrad = inp.requires_grad
return out_ret
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
grad_input = None
if ctx.requires_dgrad:
grad_output = grad_output.contiguous()
in_features = grad_output.shape[-1]
# Allocate padded gradient tensor
total_row = sum(ctx.padded_m_splits)
grad_input = torch.empty(
[total_row, in_features], dtype=grad_output.dtype, device=grad_output.device
)
# Re-pad the gradient with zeros so it matches the padded forward layout
multi_padding_fused(
grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input
)
return (grad_input, None, None, None)
class Fp8Unpadding(torch.nn.Module):
"""
Remove FP8 padding from Grouped GEMM output.
Parameters
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
"""
def __init__(
self,
num_gemms,
) -> None:
super().__init__()
self.num_gemms = num_gemms
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
) -> torch.Tensor:
"""
Apply the unpadding to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
m_splits : List[int]
List of integers giving the unpadded number of rows in each split.
"""
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# Round each split up to a multiple of 16 for FP8
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
if torch.is_grad_enabled():
fn = _Fp8Unpadding.apply
args = []
else:
fn = _Fp8Unpadding.forward
args = [None]
args += (
inp,
m_splits,
padded_m_splits,
torch.is_grad_enabled(),
)
out = fn(*args)
return out
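And the matching Fp8Unpadding sketch; its backward re-pads incoming gradients
with zeros, mirroring the forward of Fp8Padding (shapes illustrative):

    import torch
    from transformer_engine.pytorch import Fp8Unpadding

    unpad = Fp8Unpadding(num_gemms=2)
    padded = torch.randn(48 + 64, 512, device="cuda", requires_grad=True)
    out = unpad(padded, [43, 60])  # back to 43 + 60 = 103 rows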