"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "fe8fad59c9add7f8aa841fed2a0b4087b931856f"
Unverified commit 8e3561bf authored by Tim Moon, committed by GitHub

Update FP8 scale-inverse in kernels with FP8 output (#1083)



* Perform scale-inv update in cast-transpose kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update in cast and activation kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update in LayerNorm and RMSNorm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update after FP8 GEMMs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in layernorm-linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Simplify kernel to update FP8 scale-inv
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix typos
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug amax update in layernorm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug ONNX export

Use quantization scaling factor in ONNX quantize op.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @ptrendx
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug mismatched dtypes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5d5fe819
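
Background for the diffs below: every FP8 tensor in Transformer Engine carries a quantization scale and a matching dequantization scale-inverse, and the invariant this patch enforces is that scale_inv equals 1/scale at the moment the FP8 data is written. A minimal host-side sketch of the round trip (standalone C++ with hypothetical values, ignoring FP8 rounding details):

    #include <cassert>
    #include <cmath>

    int main() {
      const float scale = 140.f;            // hypothetical amax-derived quantization factor
      const float scale_inv = 1.f / scale;  // what the updated kernels now write alongside the data
      const float x = 1.7f;
      const float q = std::nearbyint(x * scale);       // quantize (real kernels emit FP8 values)
      const float deq = q * scale_inv;                 // dequantize with the stored scale-inverse
      assert(std::fabs(deq - x) <= 0.5f * scale_inv);  // error is at most half a quantization step
      return 0;
    }

If scale_inv were stale (computed from an older scale), deq would be off by the ratio of the two scales; the tests below assert that the kernels now keep the pair consistent.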
@@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
......
@@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
......
@@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
......
@@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
@@ -139,6 +139,10 @@ void performTest() {
output_c_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
compareResults("scale_inv",
output_c_list[tensor_id].scale_inv(),
1.f / output_c_list[tensor_id].scale(),
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
......
@@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
@@ -44,6 +44,7 @@ set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
pycudnn.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
......
@@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), output->data.shape[0],
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), output->data.shape[0],
output->data.shape[1], {},
stream);); // NOLINT(*)
); // NOLINT(*)
......
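
The activation changes above only thread one extra pointer (output->scale_inv.dptr) through the existing vectorized helpers. As a self-contained miniature of the kernel pattern this PR converges on (compute amax, quantize with scale, let one thread refresh scale_inv), consider this sketch; the names are hypothetical, TE's real helpers such as reduce_max and atomicMaxFloat live in utils.cuh, and the real kernels reduce amax per block before touching global memory:

    #include <cuda_fp8.h>
    #include <cuda_runtime.h>

    __device__ float atomic_max_float(float *addr, float value) {
      // Valid for non-negative values such as amax, where float and int ordering agree.
      const int old = atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
      return __int_as_float(old);
    }

    __global__ void cast_to_fp8(const float *in, __nv_fp8_e4m3 *out, const float *scale_ptr,
                                float *amax_ptr, float *scale_inv_ptr, size_t n) {
      const float scale = scale_ptr == nullptr ? 1.f : *scale_ptr;
      float amax = 0.f;
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += static_cast<size_t>(gridDim.x) * blockDim.x) {
        amax = fmaxf(amax, fabsf(in[i]));
        out[i] = __nv_fp8_e4m3(in[i] * scale);  // quantize
      }
      if (amax_ptr != nullptr) atomic_max_float(amax_ptr, amax);  // per-thread here; per-block in TE
      // Single-writer guard, as in the kernels in this PR: scale is uniform across
      // the launch, so one thread suffices to write the reciprocal.
      if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
        *scale_inv_ptr = 1.f / scale;
      }
    }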
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transformer_engine.h>
#include "./common.h"
#include "./utils.cuh"
namespace transformer_engine {
namespace {
__global__ void __launch_bounds__(1)
update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr,
float* __restrict__ scale_inv_ptr) {
const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
reciprocal<float>(scale_inv_ptr, scale);
}
} // namespace
void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) {
if (t->scale_inv.dptr != nullptr) {
update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float*>(t->scale.dptr), reinterpret_cast<float*>(t->scale_inv.dptr));
}
}
} // namespace transformer_engine
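
This new common.cu is the fallback path: update_tensor_scale_inv enqueues a single-thread kernel whose only job is writing 1/scale, for operations (like the cuBLAS GEMM below) that cannot fuse the update into their own kernels. The reciprocal helper comes from ./utils.cuh; presumably it is little more than this sketch:

    // Assumed shape of the helper; the real definition lives in ./utils.cuh.
    template <typename T>
    __device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
      *value_inv = 1 / value;
    }

Note the nullptr fallback in the kernel: a missing scale pointer is treated as a scale of 1, matching the convention the cast kernels use.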
@@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t);
/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
* with the reciprocal of the FP8 scale (quantization scaling factor).
*/
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
......
@@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
workspace, /* workspace */
workspaceSize, stream)); /* stream */
// Update FP8 scale-inv in output tensor
if (is_fp8_dtype(outputD->data.dtype)) {
update_tensor_scale_inv(outputD, stream);
}
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
......
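
cuBLAS writes the FP8 output data but knows nothing about TE's tensor metadata, so the scale-inverse refresh is enqueued on the same CUDA stream right after the matmul; stream ordering then guarantees that any later kernel on that stream dequantizes with a scale_inv consistent with the scale the GEMM used. A standalone demo of that ordering property (toy kernels standing in for the real work):

    #include <cuda_runtime.h>

    __global__ void write_scale(float *scale) { *scale = 224.f; }  // stand-in for earlier work
    __global__ void write_recip(const float *scale, float *scale_inv) { *scale_inv = 1.f / *scale; }

    int main() {
      float *scale, *scale_inv;
      cudaMalloc(&scale, sizeof(float));
      cudaMalloc(&scale_inv, sizeof(float));
      cudaStream_t stream;
      cudaStreamCreate(&stream);
      write_scale<<<1, 1, 0, stream>>>(scale);
      write_recip<<<1, 1, 0, stream>>>(scale, scale_inv);  // same stream: runs strictly after
      cudaStreamSynchronize(stream);
      cudaFree(scale);
      cudaFree(scale_inv);
      cudaStreamDestroy(stream);
      return 0;
    }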
@@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase {
// AMax output
void *amax;
// Inverse of scaling factor
void *scale_inv;
// Whether to compute scale and amax
bool fp8_out;
};
......
@@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
......
@@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
@@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
......
@@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
......
@@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
@@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
......
@@ -101,14 +101,11 @@ struct KernelConfig {
};
template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size)
cast_transpose_general_kernel(const IType *__restrict__ const input,
const CType *__restrict__ const noop,
OType *__restrict__ const output_c,
OType *__restrict__ const output_t,
const CType *__restrict__ const scale_ptr,
CType *__restrict__ const amax_ptr, const size_t row_length,
const size_t num_rows) {
__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
const IType *__restrict__ const input, const CType *__restrict__ const noop,
OType *__restrict__ const output_c, OType *__restrict__ const output_t,
const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr,
CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
@@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size)
if (amax_ptr != nullptr) {
amax = reduce_max<warps_per_tile>(amax, tidy);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
}
}
} // namespace
@@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
"Cast and transposed outputs need to share amax tensor.");
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor.");
NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr,
"Cast and transposed outputs need to share scale-inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
@@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length,
num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
@@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length, num_rows);
}); // NOLINT(*)
); // NOLINT(*)
}
......
@@ -433,15 +433,19 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
}
}
/* warp tile amax reduce*/
amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) {
// Reduce amax over block
if (param.amax != nullptr) {
amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) {
reciprocal<CType>(param.scale_inv, scale);
}
}
static const char *ActTypeToString[] = {
@@ -870,17 +874,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
__syncthreads();
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) {
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax, max);
}
if (scale_inv != nullptr) {
reciprocal<float>(scale_inv, scale);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<CType>(scale_inv, scale);
}
}
@@ -1079,17 +1084,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
__syncthreads();
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) {
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax, max);
}
if (scale_inv != nullptr) {
reciprocal<float>(scale_inv, scale);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<CType>(scale_inv, scale);
}
}
......
@@ -36,6 +36,8 @@ struct MultiCastTransposeArgs {
void* scale_list[kMaxTensorsPerKernel];
// (output) AMAX's of input tensors
void* amax_list[kMaxTensorsPerKernel];
// (output) Inverse of scaling factor for output tensors
void* scale_inv_list[kMaxTensorsPerKernel];
// Input matrix heights
int num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths
@@ -82,7 +84,8 @@ __global__ void __launch_bounds__(threads_per_block)
OType* output_t = reinterpret_cast<OType*>(args.output_t_list[tensor_id]);
const CType* scale_ptr = reinterpret_cast<CType*>(args.scale_list[tensor_id]);
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
CType* amax = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
CType* amax_ptr = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
CType* scale_inv_ptr = reinterpret_cast<CType*>(args.scale_inv_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
@@ -183,7 +186,10 @@ __global__ void __launch_bounds__(threads_per_block)
local_amax = reduce_max<n_warps_per_tile>(local_amax, tidy);
if (tid == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, local_amax);
if (amax_ptr != nullptr) atomicMaxFloat(amax_ptr, local_amax);
}
if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
}
}
@@ -285,6 +291,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
......
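
In the multi-tensor kernel above, one launch covers several tensors, each split across many tiles and blocks, so the per-launch blockIdx.x == 0 guard is replaced by a per-tensor one: the first thread of each tensor's first tile is the unique writer of that tensor's scale-inverse. Distilled into a device-side sketch (hypothetical helper; tile_id and tid are derived from args.block_range exactly as in the kernel):

    __device__ void maybe_write_scale_inv(int tile_id, int tid, const float *scale_ptr,
                                          float *scale_inv_ptr) {
      // Exactly one writer per tensor: thread 0 of tile 0.
      if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) {
        *scale_inv_ptr = 1.f / (scale_ptr == nullptr ? 1.f : *scale_ptr);
      }
    }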
@@ -25,7 +25,7 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
const IType* __restrict__ const input, const CType* __restrict__ const noop,
OType* __restrict__ const output_c, OType* __restrict__ const output_t,
const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr,
const size_t row_length, const size_t num_rows) {
CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
@@ -121,4 +121,9 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
}
}