Unverified commit 51eb6362, authored by Jan Bielak, committed by GitHub
Browse files

Fuse amax computation into normalization kernel for current scaling (#2013)



* Compute amax in normalization kernels as long as the pointer is provided, even if using non quantized output
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fuse amax computation into normalization forward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Use TE layernorm kernel instead of raising error about unsupported cuDNN feature
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 44a581c1
......@@ -65,6 +65,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool is_aligned = true;
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output
}
bool gamma_in_weight_dtype = false;
if (cudnn_backend) {
// TODO: add check for GPU ARCH
......
......@@ -75,6 +75,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
......@@ -120,9 +121,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
compute_t b_ij = beta[it].data.elt[jt];
compute_t temp_output = g_ij * y_ij + b_ij;
if (params.fp8_out) {
if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
}
if (params.fp8_out) {
temp_output = temp_output * scale;
}
......@@ -132,9 +135,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
idx += VEC_COLS_PER_LDG;
}
}
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
......@@ -142,6 +145,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
......@@ -51,6 +51,10 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool is_aligned = true;
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output
}
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
......
......@@ -71,6 +71,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
......@@ -112,9 +113,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
}
compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) {
if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
}
if (params.fp8_out) {
temp_output = temp_output * scale;
}
......@@ -124,9 +127,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
idx += VEC_COLS_PER_LDG;
}
}
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
......@@ -134,6 +137,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
......@@ -110,9 +110,15 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper unquantized_out_cu;
py::object unquantized_out;
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
std::tie(unquantized_out_cu, unquantized_out) =
my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
} else {
NoneQuantizer q{none};
std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
}
}
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
......@@ -139,8 +145,13 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
} else {
my_quantizer->quantize(unquantized_out_cu, out_cu);
}
}
return {out, py::cast(mu), py::cast(rsigma)};
}
......@@ -233,9 +244,15 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
TensorWrapper unquantized_out_cu;
py::object unquantized_out;
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
std::tie(unquantized_out_cu, unquantized_out) =
my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
} else {
NoneQuantizer q{none};
std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
}
}
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
......@@ -262,8 +279,13 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
} else {
my_quantizer->quantize(unquantized_out_cu, out_cu);
}
}
return {out, py::none(), py::cast(rsigma)};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment