Unverified commit d5c088da authored by Tim Moon, committed by GitHub
Browse files

Use unoptimized layernorm kernel if pointers are not aligned (#490)


Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent 8d62d5c2
...@@ -233,14 +233,6 @@ struct BwdGeneralRegistrar{ ...@@ -233,14 +233,6 @@ struct BwdGeneralRegistrar{
} }
}; };
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
uint32_t hidden_size);
////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm } // namespace layer_norm
......
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
************************************************************************/ ************************************************************************/
#include <transformer_engine/layer_norm.h> #include <transformer_engine/layer_norm.h>
#include <cstdint>
#include <vector> #include <vector>
#include "ln.h" #include "ln.h"
#include "../common.h" #include "../common.h"
...@@ -72,11 +75,20 @@ layer_norm::FwdFunction & get_fwd_launcher(DType wtype, ...@@ -72,11 +75,20 @@ layer_norm::FwdFunction & get_fwd_launcher(DType wtype,
DType itype, DType itype,
DType otype, DType otype,
DType ctype, DType ctype,
uint32_t hidden_size, const layer_norm::FwdParams &params) {
uint32_t batch_size) {
// Look for tuned kernel // Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size); auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
if (batch_size % 4 == 0 auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.mu)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.beta)
&& is_aligned(params.z)
&& layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); return layer_norm::FWD_TUNED_FUNCS.at(tuned_key);
} }
...@@ -87,7 +99,7 @@ layer_norm::FwdFunction & get_fwd_launcher(DType wtype, ...@@ -87,7 +99,7 @@ layer_norm::FwdFunction & get_fwd_launcher(DType wtype,
NVTE_ERROR("FWD: Unsupported types."); NVTE_ERROR("FWD: Unsupported types.");
} }
auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key); auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size); auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) { if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA // Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second; return general_func_map.rbegin()->second;
...@@ -102,11 +114,24 @@ layer_norm::BwdFunction & get_bwd_launcher(DType wtype, ...@@ -102,11 +114,24 @@ layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
DType itype, DType itype,
DType otype, DType otype,
DType ctype, DType ctype,
uint32_t hidden_size, const layer_norm::BwdParams &params) {
uint32_t batch_size) {
// Look for tuned kernel // Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size); auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
if (batch_size % 4 == 0 auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.mu)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.dz)
&& is_aligned(params.dx)
&& is_aligned(params.dbeta)
&& is_aligned(params.dgamma)
&& is_aligned(params.dbeta_part)
&& is_aligned(params.dgamma_part)
&& layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { && layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); return layer_norm::BWD_TUNED_FUNCS.at(tuned_key);
} }
...@@ -117,7 +142,7 @@ layer_norm::BwdFunction & get_bwd_launcher(DType wtype, ...@@ -117,7 +142,7 @@ layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
NVTE_ERROR("BWD: Unsupported types."); NVTE_ERROR("BWD: Unsupported types.");
} }
auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key); auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size); auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) { if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA // Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second; return general_func_map.rbegin()->second;
...@@ -183,10 +208,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -183,10 +208,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
launch_params.multiprocessorCount = multiprocessorCount; launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream; launch_params.stream = stream;
// Request the kernel launcher.
auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype,
hidden_size, rows);
// Set the kernel runtime parameters. // Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params; layer_norm::FwdParams &params = launch_params.params;
params.rows = rows; params.rows = rows;
...@@ -203,6 +224,9 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -203,6 +224,9 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.fp8_out = fp8_out; params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma; params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -304,9 +328,6 @@ void layernorm_bwd(const Tensor& dz, ...@@ -304,9 +328,6 @@ void layernorm_bwd(const Tensor& dz,
launch_params.stream = stream; launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount; launch_params.multiprocessorCount = multiprocessorCount;
auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype,
hidden_size, rows);
// Set the kernel runtime parameters. // Set the kernel runtime parameters.
layer_norm::BwdParams &params = launch_params.params; layer_norm::BwdParams &params = launch_params.params;
params.rows = rows; params.rows = rows;
...@@ -323,6 +344,8 @@ void layernorm_bwd(const Tensor& dz, ...@@ -323,6 +344,8 @@ void layernorm_bwd(const Tensor& dz,
params.dgamma_part = dgamma_part->data.dptr; params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma; params.zero_centered_gamma = zero_centered_gamma;
auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
......
...@@ -290,7 +290,8 @@ struct Vec { ...@@ -290,7 +290,8 @@ struct Vec {
size_t idx = 0, size_t idx = 0,
size_t count = NUM_ELT) { size_t count = NUM_ELT) {
const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx; const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
if ( count < NUM_ELT || idx % NUM_ELT != 0 ) { if ( count < NUM_ELT
|| reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
#pragma unroll #pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) { for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = (it < count this->data.elt[it] = (it < count
...@@ -308,7 +309,8 @@ struct Vec { ...@@ -308,7 +309,8 @@ struct Vec {
size_t idx = 0, size_t idx = 0,
size_t count = NUM_ELT) const { size_t count = NUM_ELT) const {
Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx; Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
if ( count < NUM_ELT || idx % NUM_ELT != 0 ) { if ( count < NUM_ELT
|| reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
#pragma unroll #pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) { for ( int it = 0; it < NUM_ELT; it++ ) {
if ( it < count ) { if ( it < count ) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment