Unverified Commit 215db88d authored by Ruibin Cheung, committed by GitHub

[PyTorch] Implement Fp8 padding and unpadding module (#1129)



* [TE/PyTorch][MoE] Add FP8 padding and unpadding module 

 1. Add multi-tensor padding kernel for FP8 with padding size = 16.
 2. Add Fp8Padding and Fp8Unpadding modules (see the usage sketch below)
 3. Add Padded GroupedLinear unit tests
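
As a rough usage sketch of the new modules (shapes and split sizes below are illustrative, not taken from the tests), Fp8Padding pads each split of the grouped GEMM input to a multiple of 16 rows and Fp8Unpadding trims the output back:

import torch
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding, GroupedLinear, fp8_autocast

num_gemms, hidden = 3, 768
pad, unpad = Fp8Padding(num_gemms), Fp8Unpadding(num_gemms)
linear = GroupedLinear(num_gemms, hidden, 4 * hidden, bias=False, device="cuda")

m_splits = [3, 20, 32]                                   # tokens routed to each expert
inp = torch.randn(sum(m_splits), hidden, device="cuda")

with fp8_autocast(enabled=True):
    padded_inp, padded_m_splits = pad(inp, m_splits)     # padded_m_splits == [16, 32, 32]
    out = linear(padded_inp, padded_m_splits)            # FP8 grouped GEMM on padded splits
    out = unpad(out, m_splits)                           # back to sum(m_splits) == 55 rows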

---------
Signed-off-by: beinggod <zhangruibin@01.ai>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 454e3895
...@@ -13,6 +13,7 @@ add_executable(test_operator
test_layernorm.cu
test_rmsnorm.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
../test_common.cu)
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/padding.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const std::vector<std::vector<InputType>>& input_list,
std::vector<std::vector<OutputType>>& output_list,
const std::vector<size_t>& height_list,
const std::vector<size_t>& width_list,
const std::vector<int>& padded_height_list) {
using compute_t = float;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = input_list[tensor_id];
auto& output = output_list[tensor_id];
const size_t height = height_list[tensor_id];
const size_t width = width_list[tensor_id];
const size_t padded_height = padded_height_list[tensor_id];
for (size_t i = 0; i < padded_height; ++i) {
if (i < height) {
for (size_t j = 0; j < width; ++j) {
const compute_t x = static_cast<compute_t>(input[i * width + j]);
const OutputType y = static_cast<OutputType>(x);
output[i * width + j] = y;
}
} else {
for (size_t j = 0; j < width; ++j) {
output[i * width + j] = static_cast<OutputType>(0.f);
}
}
}
}
}
template <typename InputType, typename OutputType>
void performTest() {
using namespace test;
const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768},
{768,1},
{768,768},
{43,43},
{43,256},
{256,43},
{256,256}};
const size_t num_tensors = tensor_dims.size();
constexpr int align = 16;
// Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_list;
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
std::vector<std::vector<OutputType>> ref_output_list;
std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
std::vector<int> ref_padded_height_list(num_tensors);
// Initialize buffers
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
const size_t height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
const size_t padded_height = (height + align - 1) / align * align;
input_list.emplace_back(Tensor({ height, width }, itype));
output_list.emplace_back(Tensor({ padded_height, width }, otype));
auto& input = input_list.back();
auto& output = output_list.back();
fillUniform(&input);
setRandomScale(&output);
ref_input_list.emplace_back(height*width);
ref_output_list.emplace_back(padded_height*width);
std::copy(input.cpu_dptr<InputType>(),
input.cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
ref_padded_height_list[tensor_id] = padded_height;
}
// Transformer Engine implementation
auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
-> std::vector<NVTETensor> {
std::vector<NVTETensor> nvte_tensor_list;
for (auto& tensor : tensor_list) {
nvte_tensor_list.emplace_back(tensor.data());
}
return nvte_tensor_list;
};
nvte_multi_padding(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(output_list).data(),
ref_padded_height_list.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
// Reference implementation
compute_ref<InputType, OutputType>(ref_input_list,
ref_output_list,
ref_height_list,
ref_width_list,
ref_padded_height_list);
// Check correctness
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
auto [atol, rtol] = getTolerances(otype);
compareResults("output",
output_list[tensor_id],
ref_output_list[tensor_id].data(),
atol, rtol);
}
}
} // namespace
class MultiPaddingTestSuite
: public ::testing::TestWithParam<
transformer_engine::DType> {};
TEST_P(MultiPaddingTestSuite, TestMultiPadding) {
using namespace transformer_engine;
using namespace test;
const DType input_type = GetParam();
const DType output_type = input_type;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>();
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MultiPaddingTestSuite,
::testing::ValuesIn(test::all_fp_types),
[](const testing::TestParamInfo<MultiPaddingTestSuite::ParamType>& info) {
std::string name = test::typeName(info.param);
return name;
});
...@@ -7,6 +7,7 @@ import os
from typing import Dict, List, Optional
import pytest
import copy
import random
import torch
import torch.nn as nn
...@@ -30,6 +31,8 @@ from transformer_engine.pytorch import (
TransformerLayer,
LayerNorm,
InferenceParams,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
...@@ -354,6 +357,40 @@ class TorchSquaredRELU(nn.Module):
return (input > 0) * input * input
class TorchGroupedLinearWithPadding(nn.Module):
def __init__(
self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
) -> None:
super().__init__()
self.padding = Fp8Padding(num_gemms)
self.linear_fn = GroupedLinear(
num_gemms,
in_features,
out_features,
bias=bias,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
device="cuda",
)
self.unpadding = Fp8Unpadding(num_gemms)
self.fp8 = fp8
def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
if self.fp8:
orig_m_splits = m_splits
inp, m_splits = self.padding(inp, m_splits)
out = self.linear_fn(inp, m_splits)
if self.fp8:
out = self.unpadding(out, orig_m_splits)
return out
_supported_act = {
"geglu": nn.GELU(approximate="tanh"),
"gelu": nn.GELU(approximate="tanh"),
...@@ -1328,6 +1365,158 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
"""Padding tensor shapes to multiples of 16."""
padded_tokens_per_expert = [
(num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
]
hidden_states = torch.split(hidden_states, tokens_per_expert)
padded_hidden_states = []
for hidden_state, actual_num_tokens, padded_num_tokens in zip(
hidden_states, tokens_per_expert, padded_tokens_per_expert
):
padded_hidden_states.append(hidden_state)
if padded_num_tokens > actual_num_tokens:
pad_tensor = torch.zeros(
padded_num_tokens - actual_num_tokens,
hidden_state.shape[1],
dtype=hidden_state.dtype,
device=hidden_state.device,
)
padded_hidden_states.append(pad_tensor)
padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
return padded_hidden_states, padded_tokens_per_expert
def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
inputmats = torch.split(
padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
)
hidden_states = torch.cat(
[
grad_output_mat[: actual_tokens_per_expert[i]]
for i, grad_output_mat in enumerate(inputmats)
],
dim=0,
)
return hidden_states
def _generate_random_numbers(n, total_sum):
if n <= 0:
return []
# reset seed
random.seed(seed)
breaks = sorted(random.sample(range(1, total_sum), n - 1))
random_numbers = (
[breaks[0]]
+ [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]
+ [total_sum - breaks[-1]]
)
return random_numbers
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len * bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
with fp8_autocast(enabled=fp8):
if isinstance(block, TorchGroupedLinearWithPadding):
out = block(inp_hidden_states, m_splits)
else:
if fp8:
padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
inp_hidden_states, m_splits
)
padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
else:
out = block(inp_hidden_states, m_splits)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
).eval()
# Share params
with torch.no_grad():
inner_grouped_linear = grouped_linear.linear_fn
for i in range(num_gemms):
setattr(
ref_grouped_linear,
f"weight{i}",
Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
)
outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, fp8
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, fp8
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states()
......
...@@ -71,6 +71,7 @@ list(APPEND transformer_engine_SOURCES
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_runtime.cpp
util/rtc.cpp
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file padding.h
* \brief Functions handling padding.
*/
#ifndef TRANSFORMER_ENGINE_PADDING_H_
#define TRANSFORMER_ENGINE_PADDING_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Pad multiple tensors.
*
* NOTE: Only bottom padding is supported.
*
* For example, a 3x3 matrix is padded to a 4x3 matrix:
*
* source
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
*
* destination
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
* | 0 | 0 | 0 |
*
* \param[in] num_tensors Number of tensors.
* \param[in] input_list List of 2D input tensors.
* \param[in,out] output_list List of padded output tensors. Row counts are
* given by padded_num_rows_list; column counts match the input tensors.
* \param[in] padded_num_rows_list List of padded num rows corresponding to input tensors.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
const int* padded_num_rows_list, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_PADDING_H_
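
For reference (not part of the header), the bottom-padding behaviour described above corresponds to this small PyTorch sketch:

import torch
import torch.nn.functional as F

src = torch.arange(1.0, 10.0).reshape(3, 3)   # the 3x3 source matrix from the example
dst = F.pad(src, (0, 0, 0, 1))                # pad one row at the bottom -> 4x3
# dst[3] == tensor([0., 0., 0.]), matching the zero-filled destination row above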
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/padding.h>
#include <cfloat>
#include <iostream>
#include <vector>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace {
// Parameters to tune
constexpr int n_warps_per_tile = 4;
constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
constexpr int desired_load_store_size = 8;
constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
struct MultiPaddingArgs {
// (input) Data buffers for input tensors
void* input_list[kMaxTensorsPerKernel];
// (output) Data buffers for padded output tensors
void* output_list[kMaxTensorsPerKernel];
// Input matrix heights
int num_rows_list[kMaxTensorsPerKernel];
// Input matrix heights (padded)
int padded_num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths
int row_length_list[kMaxTensorsPerKernel];
// Block index ranges delimiting each tensor (prefix sum of per-tensor block counts)
int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel
int num_tensors;
};
template <int nvec, typename Type>
__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) {
using Vec = Vec<Type, nvec>;
// Thread indices
// Note: Block is interpreted as a warp_size x num_warps grid
constexpr int bdimx = THREADS_PER_WARP;
constexpr int bdimy = n_warps_per_tile;
const int tid = threadIdx.x;
const int tidx = tid % bdimx;
const int tidy = tid / bdimx;
const int bid = blockIdx.x;
// Input tensors are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
constexpr int tile_dim_n = THREADS_PER_WARP * nvec;
// Number of nvec x nvec subtiles for each thread to
// load/store
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
// Find tensor corresponding to block
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int padded_num_rows = args.padded_num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
const int tile_id = bid - args.block_range[tensor_id];
const int tile_id_m = tile_id / num_tiles_n;
const int tile_id_n = tile_id % num_tiles_n;
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;
// Copy input to output, zero-filling padded rows
// Note: Each thread processes n_iterations subtiles, copying valid
// rows and writing zeros for rows in the padded region.
Type local_zero = static_cast<Type>(0.f);
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
const int col = tile_col + j1 * nvec;
Vec local_input;
Vec local_output;
local_input.clear();
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2];
}
}
}
#pragma unroll
for (int j2 = 0; j2 < nvec; ++j2) {
local_output.data.elt[j2] = local_input.data.elt[j2];
}
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_output.data.elt[j2];
}
}
} else if (row < padded_num_rows) {
// padding
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_zero;
}
}
}
}
}
}
} // namespace
void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
const std::vector<int> padded_num_rows_list, cudaStream_t stream) {
// Check that number of tensors is valid
NVTE_CHECK(output_list.size() == input_list.size(),
"Number of input and output tensors must match");
if (input_list.empty()) {
return;
}
// Check that tensor properties are valid
DType type = input_list[0]->data.dtype;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = *input_list[tensor_id];
const auto& output = *output_list[tensor_id];
CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id));
CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id));
NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");
NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id],
"output tensor shape does not match padded input shape.");
}
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
// Add tensors to kernel argument struct
MultiPaddingArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
// Launch kernel if argument struct is full
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
kernel_args.num_tensors = 0;
}
// Calculate number of thread blocks needed for tensor
const int num_rows = input_list[tensor_id]->data.shape[0];
const int padded_num_rows = padded_num_rows_list[tensor_id];
const int row_length = input_list[tensor_id]->data.shape[1];
const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m;
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
const int num_tiles = num_tiles_m * num_tiles_n;
// Add tensor to kernel argument struct
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.padded_num_rows_list[pos] = padded_num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
kernel_args.num_tensors++;
}
// Launch kernel
if (kernel_args.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
}
}
} // namespace transformer_engine
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
const int* padded_num_rows_list, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_padding);
using namespace transformer_engine;
std::vector<Tensor*> input_list_, output_list_;
std::vector<int> padded_num_rows_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i]));
padded_num_rows_list_.push_back(padded_num_rows_list[i]);
}
multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
}
...@@ -67,6 +67,7 @@ from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
......
...@@ -11,3 +11,4 @@ from .transpose import *
from .activation import *
from .normalization import *
from .cast import *
from .padding import *
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for transpose extensions"""
from typing import List, Tuple, Union
import torch
import transformer_engine_torch as tex
__all__ = [
"multi_padding_fused",
]
def multi_padding_fused(
inp: torch.Tensor,
row_list: List[int],
padded_row_list: List[int],
out: torch.Tensor,
) -> None:
"""Pad multiple row-splits of inp into out with one fused kernel launch."""
tex.fused_multi_row_padding(
inp,
out,
row_list,
padded_row_list,
)
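
A minimal sketch of how this wrapper is called (mirroring the way _Fp8Padding.forward uses it later in this commit; sizes are illustrative): the caller allocates an output with sum(padded_row_list) rows, and the kernel copies each split while zero-filling the padded rows.

import torch
from transformer_engine.pytorch.cpp_extensions import multi_padding_fused

row_list = [3, 20]                        # original rows per split
padded_row_list = [16, 32]                # each rounded up to a multiple of 16
inp = torch.randn(sum(row_list), 64, device="cuda")
out = torch.empty(sum(padded_row_list), 64, dtype=inp.dtype, device=inp.device)
multi_padding_fused(inp, row_list, padded_row_list, out)
# rows 3:16 and 36:48 of out are now zero-filled padding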
...@@ -28,6 +28,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
......
...@@ -486,4 +486,12 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
/***************************************************************************************************
* padding
**************************************************************************************************/
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list);
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list) {
using namespace transformer_engine;
NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
"Number of input row list and padded row list must match.");
NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
const int num_tensors = input_row_list.size();
// Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, output_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
std::vector<transformer_engine::DType> input_type_list;
void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
input_dptr_list.push_back(d_input_ptr);
output_dptr_list.push_back(d_output_ptr);
// Move the input pointer to the next split.
char* input_char_ptr = reinterpret_cast<char*>(d_input_ptr);
const size_t input_dptr_offset =
input_row_list[tensor_id] * input.size(1) * input.element_size();
input_char_ptr += input_dptr_offset;
d_input_ptr = reinterpret_cast<void*>(input_char_ptr);
input_shape_list.push_back({input_row_list[tensor_id], static_cast<size_t>(input.size(1))});
input_type_list.push_back(GetTransformerEngineDType(input.scalar_type()));
// Move the output pointer to the next split.
char* output_char_ptr = reinterpret_cast<char*>(d_output_ptr);
const size_t output_dptr_offset =
padded_input_row_list[tensor_id] * output.size(1) * output.element_size();
output_char_ptr += output_dptr_offset;
d_output_ptr = reinterpret_cast<void*>(output_char_ptr);
output_shape_list.push_back(
{padded_input_row_list[tensor_id], static_cast<size_t>(output.size(1))});
}
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list, nvte_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype) -> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
return tensor_wrappers.back().data();
};
std::vector<int> padded_num_rows_list;
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue;
nvte_input_list.emplace_back(
make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i]));
nvte_output_list.emplace_back(
make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i]));
padded_num_rows_list.emplace_back(padded_input_row_list[i]);
}
// Check tensor lists
NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(),
"Number of input and output tensors must match");
NVTE_CHECK(padded_num_rows_list.size() == nvte_input_list.size(),
"Number of input and padded row lists must match");
// Launch TE kernel
nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
}
...@@ -152,7 +152,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
py::call_guard<py::gil_scoped_release>());
......
...@@ -9,4 +9,6 @@ from .grouped_linear import GroupedLinear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
from .fp8_padding import Fp8Padding
from .fp8_unpadding import Fp8Unpadding
from .base import initialize_ub, destroy_ub
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 Padding API"""
from typing import List, Tuple
import torch
from ..cpp_extensions import (
multi_padding_fused,
)
from ..jit import no_torch_dynamo
__all__ = ["Fp8Padding"]
class _Fp8Padding(torch.autograd.Function):
"""functional FP8 padding"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = inp.shape[-1]
# Allocate padded output tensor
total_row = sum(padded_m_splits)
out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device)
multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out)
if is_grad_enabled:
ctx.m_splits = m_splits
ctx.padded_m_splits = padded_m_splits
ctx.requires_dgrad = inp.requires_grad
return out
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
grad_input = None
if ctx.requires_dgrad:
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits
)
grad_input = torch.cat(
[
grad_output_mat[: ctx.m_splits[i]]
for i, grad_output_mat in enumerate(grad_output_mats)
],
dim=0,
)
return (grad_input, None, None, None)
class Fp8Padding(torch.nn.Module):
"""
Pad each split of a grouped GEMM input to a multiple of 16 rows.
Parameters
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
"""
def __init__(
self,
num_gemms,
) -> None:
super().__init__()
self.num_gemms = num_gemms
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
) -> Tuple[torch.Tensor, List[int]]:
"""
Apply the padding to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
m_splits : List[int]
List of integers representing the split of the input tensor.
"""
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# Round each split up to a multiple of 16 rows for FP8
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
if torch.is_grad_enabled():
fn = _Fp8Padding.apply
args = []
else:
fn = _Fp8Padding.forward
args = [None]
args += (
inp,
m_splits,
padded_m_splits,
torch.is_grad_enabled(),
)
out = fn(*args)
return out, padded_m_splits
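
A brief illustrative example of the module above (the sizes are assumptions, not taken from the tests):

import torch
from transformer_engine.pytorch import Fp8Padding

pad = Fp8Padding(num_gemms=3)
inp = torch.randn(3 + 20 + 32, 256, device="cuda")
out, padded_m_splits = pad(inp, [3, 20, 32])
# padded_m_splits == [16, 32, 32]; out has sum(padded_m_splits) == 80 rows,
# with zero-filled rows appended at the end of each split.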
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 Padding API"""
from typing import List
import torch
from ..cpp_extensions import (
multi_padding_fused,
)
from ..jit import no_torch_dynamo
__all__ = ["Fp8Unpadding"]
class _Fp8Unpadding(torch.autograd.Function):
"""functional FP8 unpadding"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
padded_m_splits: List[int],
is_grad_enabled: bool,
) -> torch.Tensor:
inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
out_ret = torch.cat(
[grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0
)
if is_grad_enabled:
ctx.m_splits = m_splits
ctx.padded_m_splits = padded_m_splits
ctx.requires_dgrad = inp.requires_grad
return out_ret
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
grad_input = None
if ctx.requires_dgrad:
grad_output = grad_output.contiguous()
in_features = grad_output.shape[-1]
# Allocate padded gradient tensor
total_row = sum(ctx.padded_m_splits)
grad_input = torch.empty(
[total_row, in_features], dtype=grad_output.dtype, device=grad_output.device
)
# Re-pad the incoming gradient to the padded row layout
multi_padding_fused(
grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input
)
return (grad_input, None, None, None)
class Fp8Unpadding(torch.nn.Module):
"""
Remove the row padding from a grouped GEMM output.
Parameters
----------
num_gemms: int
number of GEMMs to be performed simultaneously.
"""
def __init__(
self,
num_gemms,
) -> None:
super().__init__()
self.num_gemms = num_gemms
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
) -> torch.Tensor:
"""
Apply the unpadding to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
m_splits : List[int]
List of integers representing the split of the input tensor.
"""
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
# Recompute the padded split sizes (multiples of 16 for FP8)
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
if torch.is_grad_enabled():
fn = _Fp8Unpadding.apply
args = []
else:
fn = _Fp8Unpadding.forward
args = [None]
args += (
inp,
m_splits,
padded_m_splits,
torch.is_grad_enabled(),
)
out = fn(*args)
return out
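
And a matching sketch for the unpadding module (illustrative sizes): forward trims each padded split back to its original row count, and backward re-pads the incoming gradient with multi_padding_fused.

import torch
from transformer_engine.pytorch import Fp8Unpadding

unpad = Fp8Unpadding(num_gemms=3)
padded = torch.randn(80, 256, device="cuda", requires_grad=True)  # splits padded to [16, 32, 32]
out = unpad(padded, [3, 20, 32])          # 55 rows
out.sum().backward()                      # gradient is re-padded to 80 rows
assert padded.grad.shape == (80, 256)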