"vscode:/vscode.git/clone" did not exist on "b0846aaf37a271d56019b42121ebcea0ac42e8c1"
Unverified Commit 8e3561bf authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Update FP8 scale-inverse in kernels with FP8 output (#1083)



* Perform scale-inv update in cast-transpose kernels
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Perform scale-inv update in cast and activation kernels
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Perform scale-inv update in LayerNorm and RMSNorm kernels
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Perform scale-inv update after FP8 GEMMs
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in linear module
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in layernorm-linear module
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Simplify kernel to update FP8 scale-inv
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix typos
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug amax update in layernorm kernels
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug test failures
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug ONNX export

Use quantization scaling factor in ONNX quantize op.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Review suggestion from @ptrendx
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug mismatched dtypes
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5d5fe819
...@@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) { ...@@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) { if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
......
...@@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) { ...@@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) { if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
......
...@@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) { ...@@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) { if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
......
...@@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) ...@@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) { if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
...@@ -139,6 +139,10 @@ void performTest() { ...@@ -139,6 +139,10 @@ void performTest() {
output_c_list[tensor_id].amax(), output_c_list[tensor_id].amax(),
ref_amax_list[tensor_id], ref_amax_list[tensor_id],
atol_amax, rtol_amax); atol_amax, rtol_amax);
compareResults("scale_inv",
output_c_list[tensor_id].scale_inv(),
1.f / output_c_list[tensor_id].scale(),
atol_amax, rtol_amax);
} }
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", compareResults("output_c",
......
...@@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) ...@@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) { if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
...@@ -44,6 +44,7 @@ set(transformer_engine_SOURCES) ...@@ -44,6 +44,7 @@ set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES list(APPEND transformer_engine_SOURCES
pycudnn.cpp pycudnn.cpp
transformer_engine.cpp transformer_engine.cpp
common.cu
transpose/cast_transpose.cu transpose/cast_transpose.cu
transpose/transpose.cu transpose/transpose.cu
transpose/cast_transpose_fusion.cu transpose/cast_transpose_fusion.cu
......
...@@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { ...@@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr), reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {}, reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*) stream);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
...@@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream ...@@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream
reinterpret_cast<const IType *>(input.data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr), reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {}, reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*) stream);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
...@@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { ...@@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr), reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), output->data.shape[0], reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), output->data.shape[0],
output->data.shape[1], {}, output->data.shape[1], {},
stream);); // NOLINT(*) stream);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transformer_engine.h>
#include "./common.h"
#include "./utils.cuh"
namespace transformer_engine {

namespace {

/*! \brief Compute an FP8 tensor's scale-inverse on device.
 *
 * Runs with a single thread: reads the quantization scale (treating a
 * null pointer as a scale of 1) and writes its reciprocal — the
 * dequantization scaling factor — to scale_inv_ptr.
 */
__global__ void __launch_bounds__(1)
    update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr,
                                   float* __restrict__ scale_inv_ptr) {
  float scale = 1.f;
  if (scale_ptr != nullptr) {
    scale = *scale_ptr;
  }
  reciprocal<float>(scale_inv_ptr, scale);
}

}  // namespace

void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) {
  // No scale-inverse buffer means there is nothing to update.
  if (t->scale_inv.dptr == nullptr) {
    return;
  }
  const auto* scale_ptr = reinterpret_cast<const float*>(t->scale.dptr);
  auto* scale_inv_ptr = reinterpret_cast<float*>(t->scale_inv.dptr);
  // One thread is sufficient: the update is a single scalar reciprocal.
  update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(scale_ptr, scale_inv_ptr);
}

}  // namespace transformer_engine
...@@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t); bool is_fp8_dtype(const DType t);
/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
* with the reciprocal of the FP8 scale (quantization scaling factor).
*/
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
#define NVTE_API_CALL(api_name) \ #define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
......
...@@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
workspace, /* workspace */ workspace, /* workspace */
workspaceSize, stream)); /* stream */ workspaceSize, stream)); /* stream */
// Update FP8 scale-inv in output tensor
if (is_fp8_dtype(outputD->data.dtype)) {
update_tensor_scale_inv(outputD, stream);
}
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
......
...@@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase { ...@@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase {
// AMax output // AMax output
void *amax; void *amax;
// Inverse of scaling factor
void *scale_inv;
// Whether to compute scale and amax // Whether to compute scale and amax
bool fp8_out; bool fp8_out;
}; };
......
...@@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.epsilon = epsilon; params.epsilon = epsilon;
params.amax = z->amax.dptr; params.amax = z->amax.dptr;
params.scale = z->scale.dptr; params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out; params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma; params.zero_centered_gamma = zero_centered_gamma;
......
...@@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( ...@@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
} }
} }
if (params.fp8_out) { if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); // Reduce amax over block
if (threadIdx.x == 0 && threadIdx.y == 0) { if (params.amax != nullptr) {
static_assert(std::is_same<compute_t, float>::value); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax); if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
} }
} }
} }
...@@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
// Finalize fp8 factors // Finalize fp8 factors
if (params.fp8_out) { if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); // Reduce amax over block
if (threadIdx.x == 0) { if (params.amax != nullptr) {
static_assert(std::is_same<compute_t, float>::value); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax); if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
} }
} }
} }
......
...@@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.epsilon = epsilon; params.epsilon = epsilon;
params.amax = z->amax.dptr; params.amax = z->amax.dptr;
params.scale = z->scale.dptr; params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out; params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma; params.zero_centered_gamma = zero_centered_gamma;
......
...@@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke ...@@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
} }
} }
if (params.fp8_out) { if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); // Reduce amax over block
if (threadIdx.x == 0 && threadIdx.y == 0) { if (params.amax != nullptr) {
static_assert(std::is_same<compute_t, float>::value); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax); if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
} }
} }
} }
...@@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
// Finalize fp8 factors // Finalize fp8 factors
if (params.fp8_out) { if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); // Reduce amax over block
if (threadIdx.x == 0) { if (params.amax != nullptr) {
static_assert(std::is_same<compute_t, float>::value); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax); if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
} }
} }
} }
......
...@@ -101,14 +101,11 @@ struct KernelConfig { ...@@ -101,14 +101,11 @@ struct KernelConfig {
}; };
template <size_t load_size, size_t store_size, typename IType, typename OType> template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size) __global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
cast_transpose_general_kernel(const IType *__restrict__ const input, const IType *__restrict__ const input, const CType *__restrict__ const noop,
const CType *__restrict__ const noop, OType *__restrict__ const output_c, OType *__restrict__ const output_t,
OType *__restrict__ const output_c, const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr,
OType *__restrict__ const output_t, CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
const CType *__restrict__ const scale_ptr,
CType *__restrict__ const amax_ptr, const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return; if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes // Vectorized load/store sizes
...@@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size) ...@@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size)
if (amax_ptr != nullptr) { if (amax_ptr != nullptr) {
amax = reduce_max<warps_per_tile>(amax, tidy); amax = reduce_max<warps_per_tile>(amax, tidy);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax_ptr, amax); atomicMaxFloat(amax_ptr, amax);
} }
} }
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
}
} }
} // namespace } // namespace
...@@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output ...@@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
"Cast and transposed outputs need to share amax tensor."); "Cast and transposed outputs need to share amax tensor.");
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor."); "Cast and transposed outputs need to share scale tensor.");
NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr,
"Cast and transposed outputs need to share scale-inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType, input.data.dtype, InputType,
...@@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output ...@@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr), static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr), static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr), static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows); static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length,
num_rows);
} else { // Statically-compiled general kernel } else { // Statically-compiled general kernel
constexpr size_t load_size = 4; constexpr size_t load_size = 4;
constexpr size_t store_size = 4; constexpr size_t store_size = 4;
...@@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output ...@@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr), static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr), static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr), static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows); static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length, num_rows);
}); // NOLINT(*) }); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
......
...@@ -433,15 +433,19 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) ...@@ -433,15 +433,19 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
} }
} }
/* warp tile amax reduce*/ // Reduce amax over block
amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id); if (param.amax != nullptr) {
amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value); static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) {
atomicMaxFloat(param.amax, amax); atomicMaxFloat(param.amax, amax);
} }
} }
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) {
reciprocal<CType>(param.scale_inv, scale);
}
} }
static const char *ActTypeToString[] = { static const char *ActTypeToString[] = {
...@@ -870,17 +874,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) ...@@ -870,17 +874,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
__syncthreads(); __syncthreads();
} }
/* warp tile amax reduce*/ // Reduce amax over block
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id); if (amax != nullptr) {
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value); static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) {
atomicMaxFloat(amax, max); atomicMaxFloat(amax, max);
} }
if (scale_inv != nullptr) { }
reciprocal<float>(scale_inv, scale);
} // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<CType>(scale_inv, scale);
} }
} }
...@@ -1079,17 +1084,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) ...@@ -1079,17 +1084,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
__syncthreads(); __syncthreads();
} }
/* warp tile amax reduce*/ // Reduce amax over block
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id); if (amax != nullptr) {
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value); static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) {
atomicMaxFloat(amax, max); atomicMaxFloat(amax, max);
} }
if (scale_inv != nullptr) { }
reciprocal<float>(scale_inv, scale);
} // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<CType>(scale_inv, scale);
} }
} }
......
...@@ -36,6 +36,8 @@ struct MultiCastTransposeArgs { ...@@ -36,6 +36,8 @@ struct MultiCastTransposeArgs {
void* scale_list[kMaxTensorsPerKernel]; void* scale_list[kMaxTensorsPerKernel];
// (output) AMAX's of input tensors // (output) AMAX's of input tensors
void* amax_list[kMaxTensorsPerKernel]; void* amax_list[kMaxTensorsPerKernel];
// (output) Inverse of scaling factor for output tensors
void* scale_inv_list[kMaxTensorsPerKernel];
// Input matrix heights // Input matrix heights
int num_rows_list[kMaxTensorsPerKernel]; int num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths // Input matrix widths
...@@ -82,7 +84,8 @@ __global__ void __launch_bounds__(threads_per_block) ...@@ -82,7 +84,8 @@ __global__ void __launch_bounds__(threads_per_block)
OType* output_t = reinterpret_cast<OType*>(args.output_t_list[tensor_id]); OType* output_t = reinterpret_cast<OType*>(args.output_t_list[tensor_id]);
const CType* scale_ptr = reinterpret_cast<CType*>(args.scale_list[tensor_id]); const CType* scale_ptr = reinterpret_cast<CType*>(args.scale_list[tensor_id]);
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr; const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
CType* amax = reinterpret_cast<CType*>(args.amax_list[tensor_id]); CType* amax_ptr = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
CType* scale_inv_ptr = reinterpret_cast<CType*>(args.scale_inv_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id]; const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id]; const int row_length = args.row_length_list[tensor_id];
...@@ -183,7 +186,10 @@ __global__ void __launch_bounds__(threads_per_block) ...@@ -183,7 +186,10 @@ __global__ void __launch_bounds__(threads_per_block)
local_amax = reduce_max<n_warps_per_tile>(local_amax, tidy); local_amax = reduce_max<n_warps_per_tile>(local_amax, tidy);
if (tid == 0) { if (tid == 0) {
static_assert(std::is_same<CType, float>::value); static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, local_amax); if (amax_ptr != nullptr) atomicMaxFloat(amax_ptr, local_amax);
}
if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
} }
} }
...@@ -285,6 +291,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, ...@@ -285,6 +291,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr; kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr; kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
kernel_args.num_rows_list[pos] = num_rows; kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length; kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
......
...@@ -25,7 +25,7 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( ...@@ -25,7 +25,7 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
const IType* __restrict__ const input, const CType* __restrict__ const noop, const IType* __restrict__ const input, const CType* __restrict__ const noop,
OType* __restrict__ const output_c, OType* __restrict__ const output_t, OType* __restrict__ const output_c, OType* __restrict__ const output_t,
const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr, const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr,
const size_t row_length, const size_t num_rows) { CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return; if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes // Vectorized load/store sizes
...@@ -121,4 +121,9 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( ...@@ -121,4 +121,9 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
atomicMaxFloat(amax_ptr, amax); atomicMaxFloat(amax_ptr, amax);
} }
} }
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment