Unverified commit 51eb6362, authored by Jan Bielak, committed by GitHub

Fuse amax computation into normalization kernel for current scaling (#2013)



* Compute amax in normalization kernels whenever the amax pointer is provided, even when the output is not quantized
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fuse amax computation into normalization forward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Use the TE layernorm kernel instead of raising an error about an unsupported cuDNN feature
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 44a581c1
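The commits above implement one pattern: rather than launching a separate amax-reduction kernel before casting to FP8 under current scaling, the normalization kernel tracks the absolute maximum of its own outputs while writing them, reduces the partial maxima across the block, and folds the block result into a global amax. The sketch below shows that pattern in isolation. It is a simplified stand-in for the tuned kernels in the diff (one thread per row, no vectorization, one atomic per thread instead of per block), and atomic_max_float is a hypothetical helper playing the role of TE's atomicMaxFloat; the amax buffer must be zeroed before launch.

#include <cuda_runtime.h>
#include <math.h>

// Hypothetical helper: atomic max for non-negative floats. The IEEE-754
// bit patterns of non-negative floats order the same way as integers, so
// an integer atomicMax on the raw bits computes a float max.
__device__ float atomic_max_float(float *addr, float value) {
  return __int_as_float(
      atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value)));
}

// Simplified sketch: RMSNorm over `cols`-wide rows, one thread per row,
// with the amax folded into the same pass that writes the outputs.
__global__ void rmsnorm_with_amax(const float *x, const float *gamma,
                                  float *y, float *amax_out, int rows,
                                  int cols, float epsilon) {
  const int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= rows) return;

  // Root mean square of the row.
  float sq_sum = 0.f;
  for (int c = 0; c < cols; ++c) {
    const float v = x[row * cols + c];
    sq_sum += v * v;
  }
  const float rsigma = rsqrtf(sq_sum / cols + epsilon);

  // Write outputs and track the running amax in the same loop, so no
  // second pass over the data is needed.
  float amax = 0.f;
  for (int c = 0; c < cols; ++c) {
    const float out = gamma[c] * x[row * cols + c] * rsigma;
    amax = fmaxf(amax, fabsf(out));
    y[row * cols + c] = out;
  }

  // Fold this thread's amax into the global result (amax_out must be
  // zero-initialized by the caller).
  atomic_max_float(amax_out, amax);
}

Launched with enough blocks to cover all rows and amax_out zeroed, a single pass produces both the normalized tensor and its global amax.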
@@ -65,6 +65,10 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
   bool is_aligned = true;
   bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
+    cudnn_backend = false;  // cuDNN does not currently support amax output for non quantized output
+  }
   bool gamma_in_weight_dtype = false;
   if (cudnn_backend) {
     // TODO: add check for GPU ARCH
...
@@ -75,6 +75,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
     scale = *reinterpret_cast<compute_t *>(params.scale);
   }
   compute_t amax = 0;
+  const bool requires_amax = params.amax != nullptr;
   for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
     Ivec x[LDGS];
@@ -120,9 +121,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
         compute_t b_ij = beta[it].data.elt[jt];
         compute_t temp_output = g_ij * y_ij + b_ij;
-        if (params.fp8_out) {
+        if (requires_amax) {
           __builtin_assume(amax >= 0);
           amax = fmaxf(amax, fabsf(temp_output));
+        }
+        if (params.fp8_out) {
           temp_output = temp_output * scale;
         }
@@ -132,16 +135,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
       idx += VEC_COLS_PER_LDG;
     }
   }
-  if (params.fp8_out) {
-    // Reduce amax over block
-    if (params.amax != nullptr) {
-      amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
-      if (threadIdx.x == 0) {
-        static_assert(std::is_same<compute_t, float>::value);
-        atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
-      }
-    }
+  // Reduce amax over block
+  if (requires_amax) {
+    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
+    if (threadIdx.x == 0) {
+      static_assert(std::is_same<compute_t, float>::value);
+      atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
+    }
+  }
+  if (params.fp8_out) {
     // Update scale-inverse
     if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
       reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...
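The last hunk above moves the block-level reduction out from under the fp8_out branch, so it runs whenever params.amax is set, FP8 output or not. A minimal version of that reduce-then-single-atomic structure, assuming a 1-D block whose size is a multiple of 32 with every thread participating (TE's reduce_max<WARPS_M * WARPS_N> helper is more general):

// Each thread passes in its partial amax; warps reduce with shuffles,
// warp leaders combine through shared memory, and thread 0 issues the
// block's single global atomic.
__device__ float atomic_max_float(float *addr, float value) {
  // Valid for non-negative values, as in the earlier sketch.
  return __int_as_float(
      atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value)));
}

__device__ void block_amax_to_global(float thread_amax, float *global_amax) {
  __shared__ float warp_max[32];  // one slot per possible warp

  // Reduce within the warp.
  for (int offset = 16; offset > 0; offset >>= 1) {
    thread_amax = fmaxf(thread_amax,
                        __shfl_down_sync(0xffffffffu, thread_amax, offset));
  }

  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  if (lane == 0) warp_max[warp] = thread_amax;
  __syncthreads();

  // Thread 0 combines the warp results and issues the only atomic.
  if (threadIdx.x == 0) {
    float block_max = warp_max[0];
    const int num_warps = blockDim.x / 32;
    for (int w = 1; w < num_warps; ++w) {
      block_max = fmaxf(block_max, warp_max[w]);
    }
    atomic_max_float(global_amax, block_max);
  }
}

Issuing one atomic per block rather than per thread keeps contention on the global amax negligible, which is why the kernels above reduce first and let only threadIdx.x == 0 touch params.amax.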
@@ -51,6 +51,10 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
   bool is_aligned = true;
   bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
+    cudnn_backend = false;  // cuDNN does not currently support amax output for non quantized output
+  }
   bool training =
       is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
...
@@ -71,6 +71,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
     scale = *reinterpret_cast<compute_t *>(params.scale);
   }
   compute_t amax = 0;
+  const bool requires_amax = params.amax != nullptr;
   for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
     Ivec x[LDGS];
@@ -112,9 +113,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
         }
         compute_t temp_output = g_ij * y_ij;
-        if (params.fp8_out) {
+        if (requires_amax) {
           __builtin_assume(amax >= 0);
           amax = fmaxf(amax, fabsf(temp_output));
+        }
+        if (params.fp8_out) {
           temp_output = temp_output * scale;
         }
@@ -124,16 +127,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
       idx += VEC_COLS_PER_LDG;
     }
   }
-  if (params.fp8_out) {
-    // Reduce amax over block
-    if (params.amax != nullptr) {
-      amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
-      if (threadIdx.x == 0) {
-        static_assert(std::is_same<compute_t, float>::value);
-        atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
-      }
-    }
+  // Reduce amax over block
+  if (requires_amax) {
+    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
+    if (threadIdx.x == 0) {
+      static_assert(std::is_same<compute_t, float>::value);
+      atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
+    }
+  }
+  if (params.fp8_out) {
     // Update scale-inverse
     if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
       reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...
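Under current scaling, the amax these kernels produce feeds directly into the quantization scale, which is why computing it inside the normalization pass saves an extra reduction over the whole output tensor. A sketch of the arithmetic, using the fact that 448 is the largest finite FP8 E4M3 value; the guard against zero or non-finite amax is recipe-specific and shown here only as one plausible choice:

#include <cmath>

// Sketch: turn an amax into a current-scaling quantization scale.
constexpr float FP8_E4M3_MAX = 448.0f;  // largest finite E4M3 value

inline float scale_from_amax(float amax) {
  // Keep the scale finite for empty or all-zero tensors.
  if (amax <= 0.0f || !std::isfinite(amax)) return 1.0f;
  return FP8_E4M3_MAX / amax;
}

// The quantizer then stores q = round_to_fp8(x * scale) together with
// scale_inv = 1 / scale for dequantization, matching the scale-inverse
// update visible in the kernels above.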
@@ -110,8 +110,14 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   TensorWrapper unquantized_out_cu;
   py::object unquantized_out;
   if (force_unfused_kernel) {
-    NoneQuantizer q{none};
-    std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      std::tie(unquantized_out_cu, unquantized_out) =
+          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
+    } else {
+      NoneQuantizer q{none};
+      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
+    }
   }
   TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
@@ -139,7 +145,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
-    my_quantizer->quantize(unquantized_out_cu, out_cu);
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
+    } else {
+      my_quantizer->quantize(unquantized_out_cu, out_cu);
+    }
   }
   return {out, py::cast(mu), py::cast(rsigma)};
@@ -233,8 +244,14 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   TensorWrapper unquantized_out_cu;
   py::object unquantized_out;
   if (force_unfused_kernel) {
-    NoneQuantizer q{none};
-    std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      std::tie(unquantized_out_cu, unquantized_out) =
+          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
+    } else {
+      NoneQuantizer q{none};
+      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
+    }
   }
   TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
@@ -262,7 +279,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
-    my_quantizer->quantize(unquantized_out_cu, out_cu);
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
+    } else {
+      my_quantizer->quantize(unquantized_out_cu, out_cu);
+    }
   }
   return {out, py::none(), py::cast(rsigma)};
...
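Taken together, the binding changes give current-scaling quantizers a two-step unfused path: allocate the high-precision output together with an amax buffer (create_hp_tensor_with_amax), let the normalization kernel fill the amax as a side effect, then cast with quantize_with_amax instead of re-reducing. The skeleton below models only the dispatch shape, with hypothetical stand-in types; the real Quantizer, Float8CurrentScalingQuantizer, and TensorWrapper classes live in the TE sources:

// Hypothetical stand-in types; only the dispatch structure mirrors the diff.
struct TensorHandle {};

struct Quantizer {
  virtual ~Quantizer() = default;
  virtual void quantize(const TensorHandle &in, TensorHandle &out) = 0;
};

struct Float8CurrentScalingQuantizer : Quantizer {
  // Allocates the high-precision tensor together with an amax buffer so
  // the norm kernel can fill the amax in the same pass.
  TensorHandle create_hp_tensor_with_amax() { return {}; }
  // Quantizes reusing the already-computed amax, skipping the separate
  // amax-reduction kernel.
  void quantize_with_amax(const TensorHandle &in, TensorHandle &out) {}
  void quantize(const TensorHandle &in, TensorHandle &out) override {}
};

// As in the diff, a runtime dynamic_cast picks the fused path only for
// current scaling; every other quantizer takes the generic path.
void quantize_output(Quantizer *q, const TensorHandle &hp, TensorHandle &out) {
  if (auto *cs = dynamic_cast<Float8CurrentScalingQuantizer *>(q)) {
    cs->quantize_with_amax(hp, out);
  } else {
    q->quantize(hp, out);
  }
}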