"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "0cf10d1c0d8b7fedb5189380409357d0b0eebf90"
Unverified commit cd37379d, authored by Przemyslaw Tredak, committed by GitHub
Browse files

Fix the failing test cases in the CI (#1806)



* Modify the test cases
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Make the tests reproducible on different machines
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixed the cache of the gamma_in_weight_dtype setting
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Reinstate the tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* More verbose code and comments
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent fe9a786c
...@@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = { ...@@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{256, 256}, {256, 256},
{993, 512}, {993, 512},
{768, 1024}, {768, 1024},
{65536, 128}, {65504, 128},
{16384, 1632}, {16384, 1632},
}; };
......
...@@ -694,6 +694,19 @@ std::pair<double, double> getTolerances(const DType type) { ...@@ -694,6 +694,19 @@ std::pair<double, double> getTolerances(const DType type) {
template <typename T> template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
// Check how many RNG calls are required to generate one uniform random value
int rng_calls_per_val = 0;
{
std::mt19937 gen1 = *gen, gen2 = *gen;
std::uniform_real_distribution<> dis(-2.0, 1.0);
const float _ = dis(gen1);
while (gen2 != gen1) {
auto _ = gen2();
++rng_calls_per_val;
}
}
// Generate uniform random values in parallel
#pragma omp parallel proc_bind(spread) #pragma omp parallel proc_bind(spread)
{ {
std::mt19937 gen_local = *gen; std::mt19937 gen_local = *gen;
...@@ -702,14 +715,14 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { ...@@ -702,14 +715,14 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
const int chunk_size = (size + threads_num - 1) / threads_num; const int chunk_size = (size + threads_num - 1) / threads_num;
const int idx_min = chunk_size * thread_ID; const int idx_min = chunk_size * thread_ID;
const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size)); const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
gen_local.discard(idx_min); gen_local.discard(idx_min * rng_calls_per_val);
std::uniform_real_distribution<> dis(-2.0, 1.0); std::uniform_real_distribution<> dis(-2.0, 1.0);
for (int i = idx_min; i < idx_max; ++i) { for (int i = idx_min; i < idx_max; ++i) {
data[i] = static_cast<T>(dis(gen_local)); data[i] = static_cast<T>(dis(gen_local));
} }
} }
gen->discard(size); gen->discard(size * rng_calls_per_val);
} }
void fillUniform(Tensor *t) { void fillUniform(Tensor *t) {
......
...@@ -185,7 +185,7 @@ def _get_tolerances(dtype): ...@@ -185,7 +185,7 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5} return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32: if dtype == torch.float32:
return {"rtol": 1.3e-6, "atol": 4e-5} return {"rtol": 1e-4, "atol": 1e-4}
raise ValueError(f"Unsupported dtype ({dtype})") raise ValueError(f"Unsupported dtype ({dtype})")
......
...@@ -39,8 +39,6 @@ Compute always in FP32 ...@@ -39,8 +39,6 @@ Compute always in FP32
namespace transformer_engine { namespace transformer_engine {
namespace normalization { namespace normalization {
bool& use_zero_centered_gamma_in_weight_dtype();
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
return training ? cudnn_frontend::NormFwdPhase_t::TRAINING return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
: cudnn_frontend::NormFwdPhase_t::INFERENCE; : cudnn_frontend::NormFwdPhase_t::INFERENCE;
...@@ -49,13 +47,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { ...@@ -49,13 +47,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode, bool training) { bool is_tuned, NVTEScalingMode mode, bool training,
// TODO: Add scaling_mode to general_key is needed bool gamma_in_weight_dtype) {
uint64_t general_key = static_cast<uint32_t>(itype) | (static_cast<uint32_t>(otype) << 3) | static_assert(NVTE_INVALID_SCALING < 1024,
(static_cast<uint32_t>(ctype) << 6) | (static_cast<uint32_t>(wtype) << 9) | "This function assumes at most 10 bits used in the scaling mode.");
(uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType");
(uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) | uint64_t general_key = static_cast<uint64_t>(itype) | (static_cast<uint64_t>(otype) << 5) |
(uint32_t(mode) << 19) | (uint32_t(training) << 22); (static_cast<uint64_t>(ctype) << 10) |
(static_cast<uint64_t>(wtype) << 15) | (uint64_t(NormType) << 20) |
(uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) |
(uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) |
(uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38);
return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned);
} }
...@@ -466,11 +468,12 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( ...@@ -466,11 +468,12 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
DType itype, DType otype, const size_t batch_size, const size_t hidden_size, DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode, const bool training) { const NVTEScalingMode mode, const bool training, const bool gamma_in_weight_dtype) {
const DType ctype = DType::kFloat32; const DType ctype = DType::kFloat32;
bool is_tuned = is_aligned && (batch_size % 4 == 0); bool is_tuned = is_aligned && (batch_size % 4 == 0);
auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, auto key =
hidden_size, zero_centered_gamma, is_tuned, mode, training); get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size,
zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype);
auto it = normalizationPlanMap.find(key); auto it = normalizationPlanMap.find(key);
if (it != normalizationPlanMap.end()) { if (it != normalizationPlanMap.end()) {
...@@ -528,6 +531,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { ...@@ -528,6 +531,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) {
transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
} }
// Only for testing, not thread-safe
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype); NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
......
...@@ -159,7 +159,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, ...@@ -159,7 +159,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true); bool training = true, bool gamma_in_weight_dtype = false);
template <typename KernelParamsType> template <typename KernelParamsType>
class TeNormalizationRegistry { class TeNormalizationRegistry {
...@@ -307,7 +307,8 @@ class NormalizationPlanRegistry { ...@@ -307,7 +307,8 @@ class NormalizationPlanRegistry {
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true); const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true,
const bool gamma_in_weight_dtype = false);
private: private:
NormalizationPlanRegistry() {} NormalizationPlanRegistry() {}
...@@ -381,6 +382,8 @@ bool is_ptr_aligned(const Args*... ptrs) { ...@@ -381,6 +382,8 @@ bool is_ptr_aligned(const Args*... ptrs) {
bool use_cudnn_norm_fwd(); bool use_cudnn_norm_fwd();
bool use_cudnn_norm_bwd(); bool use_cudnn_norm_bwd();
bool& use_zero_centered_gamma_in_weight_dtype();
} // namespace normalization } // namespace normalization
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "../../common.h" #include "../../common.h"
#include "../common.h" #include "../common.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -64,9 +65,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -64,9 +65,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool is_aligned = true; bool is_aligned = true;
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool gamma_in_weight_dtype = false;
if (cudnn_backend) { if (cudnn_backend) {
// TODO: add check for GPU ARCH // TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn; norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else { } else {
norm_backend = NVTE_Norm_Backend::Te; norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr,
...@@ -83,7 +86,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -83,7 +86,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
z->data.dtype, // otype z->data.dtype, // otype
x.data.shape[0], // batch_size x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
...@@ -150,9 +154,11 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te ...@@ -150,9 +154,11 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
NVTE_Norm_Backend norm_backend; NVTE_Norm_Backend norm_backend;
bool is_aligned = true; bool is_aligned = true;
bool gamma_in_weight_dtype = false;
if (use_cudnn_norm_bwd()) { if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH // TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn; norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else { } else {
norm_backend = NVTE_Norm_Backend::Te; norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr,
...@@ -165,7 +171,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te ...@@ -165,7 +171,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
gamma.data.dtype, // otype gamma.data.dtype, // otype
x.data.shape[0], // batch_size x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned); multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "../../common.h" #include "../../common.h"
#include "../common.h" #include "../common.h"
#include "transformer_engine/normalization.h" #include "transformer_engine/normalization.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -53,9 +54,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -53,9 +54,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool training = bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
bool gamma_in_weight_dtype = false;
if (cudnn_backend) { if (cudnn_backend) {
// TODO: add check for GPU ARCH // TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn; norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else { } else {
norm_backend = NVTE_Norm_Backend::Te; norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr);
...@@ -68,7 +71,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -68,7 +71,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
z->data.dtype, // otype z->data.dtype, // otype
x.data.shape[0], // batch_size x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
...@@ -126,9 +130,11 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -126,9 +130,11 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
NVTE_Norm_Backend norm_backend; NVTE_Norm_Backend norm_backend;
bool is_aligned = true; bool is_aligned = true;
bool gamma_in_weight_dtype = false;
if (use_cudnn_norm_bwd()) { if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH // TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn; norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else { } else {
norm_backend = NVTE_Norm_Backend::Te; norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
...@@ -141,7 +147,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -141,7 +147,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
gamma.data.dtype, // otype gamma.data.dtype, // otype
x.data.shape[0], // batch_size x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned); multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment