"vscode:/vscode.git/clone" did not exist on "48949222c637da9fc72b0ed6526ae1b40bb55237"
Unverified commit 905d94f4, authored by Tim Moon, committed by GitHub
Browse files

Use unoptimized RMSNorm kernel if pointers are not aligned (#886)


Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent e706e5fa
...@@ -4,11 +4,14 @@ ...@@ -4,11 +4,14 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "transformer_engine/rmsnorm.h"
#include <cstdint>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "../common.h"
#include "rmsnorm.h" #include "rmsnorm.h"
#include "transformer_engine/rmsnorm.h" #include "../common.h"
/* /*
...@@ -46,11 +49,23 @@ BwdTunedRegistry BWD_TUNED_FUNCS; ...@@ -46,11 +49,23 @@ BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry FWD_GENERAL_FUNCS; FwdGeneralRegistry FWD_GENERAL_FUNCS;
BwdGeneralRegistry BWD_GENERAL_FUNCS; BwdGeneralRegistry BWD_GENERAL_FUNCS;
FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, FwdFunction &get_fwd_launcher(DType wtype,
uint32_t hidden_size, uint32_t batch_size) { DType itype,
DType otype,
DType ctype,
const layer_norm::FwdParams &params) {
// Look for tuned kernel // Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size); auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
if (batch_size % 4 == 0 && FWD_TUNED_FUNCS.count(tuned_key) > 0) { auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.z)
&& FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return FWD_TUNED_FUNCS.at(tuned_key); return FWD_TUNED_FUNCS.at(tuned_key);
} }
...@@ -60,7 +75,7 @@ FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype ...@@ -60,7 +75,7 @@ FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype
NVTE_ERROR("FWD: Unsupported types."); NVTE_ERROR("FWD: Unsupported types.");
} }
auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key); auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size); auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) { if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA // Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second; return general_func_map.rbegin()->second;
...@@ -71,11 +86,26 @@ FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype ...@@ -71,11 +86,26 @@ FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, BwdFunction &get_bwd_launcher(DType wtype,
uint32_t hidden_size, uint32_t batch_size) { DType itype,
DType otype,
DType ctype,
const layer_norm::BwdParams &params) {
// Look for tuned kernel // Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size); auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
if (batch_size % 4 == 0 && BWD_TUNED_FUNCS.count(tuned_key) > 0) { auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.dz)
&& is_aligned(params.dx)
&& is_aligned(params.dgamma)
&& is_aligned(params.dgamma_part)
&& layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return BWD_TUNED_FUNCS.at(tuned_key); return BWD_TUNED_FUNCS.at(tuned_key);
} }
...@@ -85,7 +115,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype ...@@ -85,7 +115,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype
NVTE_ERROR("BWD: Unsupported types."); NVTE_ERROR("BWD: Unsupported types.");
} }
auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key); auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size); auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) { if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA // Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second; return general_func_map.rbegin()->second;
...@@ -132,9 +162,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -132,9 +162,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
launch_params.multiprocessorCount = multiprocessorCount; launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream; launch_params.stream = stream;
// Request the kernel launcher.
auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, hidden_size, rows);
// Set the kernel runtime parameters. // Set the kernel runtime parameters.
rmsnorm::FwdParams &params = launch_params.params; rmsnorm::FwdParams &params = launch_params.params;
params.rows = rows; params.rows = rows;
...@@ -151,6 +178,9 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -151,6 +178,9 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.fp8_out = fp8_out; params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma; params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) { if (launch_params.workspace_bytes == 0) {
...@@ -242,8 +272,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -242,8 +272,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
launch_params.stream = stream; launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount; launch_params.multiprocessorCount = multiprocessorCount;
auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, hidden_size, rows);
// Set the kernel runtime parameters. // Set the kernel runtime parameters.
rmsnorm::BwdParams &params = launch_params.params; rmsnorm::BwdParams &params = launch_params.params;
params.rows = rows; params.rows = rows;
...@@ -260,6 +288,9 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -260,6 +288,9 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
params.dgamma_part = dgamma_part->data.dptr; params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma; params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment