Unverified commit 6a2161bf, authored by Przemyslaw Tredak, committed by GitHub
Browse files

Remove fp8_out from the LN API (#8)



* Remove fp8_out from LN API
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* fix LN test
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: ksivamani <ksivamani@nvidia.com>
parent 62f93325
...@@ -163,12 +163,12 @@ void performTest(const size_t N, const size_t H) { ...@@ -163,12 +163,12 @@ void performTest(const size_t N, const size_t H) {
float epsilon = 1e-5; float epsilon = 1e-5;
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data(), true); workspace.data(), barrier.data(), amax.data(), scale_inv.data());
workspace = Tensor(workspace.shape(), workspace.dtype()); workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype()); barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data(), true); workspace.data(), barrier.data(), amax.data(), scale_inv.data());
// Backward kernel // Backward kernel
nvte_layernorm_bwd(dz.data(), input.data(), nvte_layernorm_bwd(dz.data(), input.data(),
...@@ -195,6 +195,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -195,6 +195,7 @@ void performTest(const size_t N, const size_t H) {
float ref_amax; float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(), compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon); ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? *(scale.cpu_dptr<float>()) : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(), compute_ref_output(input.cpu_dptr<InputType>(),
gamma.cpu_dptr<WeightType>(), gamma.cpu_dptr<WeightType>(),
beta.cpu_dptr<WeightType>(), beta.cpu_dptr<WeightType>(),
...@@ -203,7 +204,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -203,7 +204,7 @@ void performTest(const size_t N, const size_t H) {
rsigma.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
N, H, N, H,
&ref_amax, &ref_amax,
*(scale.cpu_dptr<float>())); ref_scale);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(), compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(), mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
gamma.cpu_dptr<WeightType>(), gamma.cpu_dptr<WeightType>(),
...@@ -215,9 +216,11 @@ void performTest(const size_t N, const size_t H) { ...@@ -215,9 +216,11 @@ void performTest(const size_t N, const size_t H) {
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax); if (isFp8Type(otype)) {
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>()); compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax); float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5; rtol_stats = 5e-5;
......
...@@ -171,4 +171,8 @@ void fillUniform(const Tensor &t) { ...@@ -171,4 +171,8 @@ void fillUniform(const Tensor &t) {
t.from_cpu(); t.from_cpu();
} }
// Returns true when `type` is one of the FP8 formats (E4M3 or E5M2).
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}
} // namespace test } // namespace test
...@@ -156,6 +156,8 @@ const std::string &typeName(DType type); ...@@ -156,6 +156,8 @@ const std::string &typeName(DType type);
extern std::vector<DType> all_fp_types; extern std::vector<DType> all_fp_types;
/*! \brief Check whether a data type is one of the FP8 formats (E4M3 or E5M2). */
bool isFp8Type(DType type);
} // namespace test } // namespace test
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
......
...@@ -39,7 +39,6 @@ extern "C" { ...@@ -39,7 +39,6 @@ extern "C" {
* \param[out] barrier Barrier tensor. * \param[out] barrier Barrier tensor.
* \param[in,out] amax AMAX value of the output tensor. * \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor. * \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] fp8_out Whether to output FP8.
*/ */
void nvte_layernorm_fwd(const NVTETensor x, void nvte_layernorm_fwd(const NVTETensor x,
const NVTETensor gamma, const NVTETensor gamma,
...@@ -54,8 +53,7 @@ void nvte_layernorm_fwd(const NVTETensor x, ...@@ -54,8 +53,7 @@ void nvte_layernorm_fwd(const NVTETensor x,
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier, NVTETensor barrier,
NVTETensor amax, NVTETensor amax,
NVTETensor scale_inv, NVTETensor scale_inv);
bool fp8_out);
/*! \brief Compute backward of LayerNorm. /*! \brief Compute backward of LayerNorm.
......
...@@ -154,12 +154,13 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -154,12 +154,13 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
Tensor* workspace, Tensor* workspace,
Tensor* barrier, Tensor* barrier,
Tensor* amax, Tensor* amax,
Tensor *scale_inv, Tensor *scale_inv
bool fp8_out
) { ) {
auto itype = x.dtype; auto itype = x.dtype;
auto wtype = gamma.dtype; auto wtype = gamma.dtype;
auto otype = z->dtype; auto otype = z->dtype;
bool fp8_out = otype == DType::kFloat8E4M3 ||
otype == DType::kFloat8E5M2;
auto ctype = layer_norm::DType::kFloat32; auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.shape.size() == 2); NVTE_CHECK(x.shape.size() == 2);
...@@ -382,8 +383,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size ...@@ -382,8 +383,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier, NVTETensor barrier,
NVTETensor amax, NVTETensor amax,
NVTETensor scale_inv, NVTETensor scale_inv) {
bool fp8_out) {
using namespace transformer_engine; using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma), *reinterpret_cast<const Tensor*>(gamma),
...@@ -398,8 +398,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size ...@@ -398,8 +398,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
reinterpret_cast<Tensor*>(workspace), reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), reinterpret_cast<Tensor*>(barrier),
reinterpret_cast<Tensor*>(amax), reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(scale_inv), reinterpret_cast<Tensor*>(scale_inv));
fp8_out);
} }
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
......
...@@ -69,7 +69,7 @@ void ln_fwd_tuned_kernel(FwdParams params) { ...@@ -69,7 +69,7 @@ void ln_fwd_tuned_kernel(FwdParams params) {
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
compute_t scale; compute_t scale = 1.f;
if (params.fp8_out) { if (params.fp8_out) {
scale = *reinterpret_cast<compute_t*>(params.scale); scale = *reinterpret_cast<compute_t*>(params.scale);
} }
......
...@@ -119,8 +119,7 @@ void dispatch_layernorm(void* input, // i ...@@ -119,8 +119,7 @@ void dispatch_layernorm(void* input, // i
void* scale_inv, // o void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape, const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type, const transformer_engine::DType scale_inv_type,
const int multiProcessorCount, const int multiProcessorCount
const bool fp8_out
) { ) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type); auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type); auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
...@@ -139,7 +138,7 @@ void dispatch_layernorm(void* input, // i ...@@ -139,7 +138,7 @@ void dispatch_layernorm(void* input, // i
z_cu.data(), mu_cu.data(), rsigma_cu.data(), z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount, at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(), workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data(), fp8_out); scale_inv_cu.data());
// Fill workspace and barrier // Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(), auto workspace_data = allocateSpace(workspace.shape(),
...@@ -160,7 +159,7 @@ void dispatch_layernorm(void* input, // i ...@@ -160,7 +159,7 @@ void dispatch_layernorm(void* input, // i
z_cu.data(), mu_cu.data(), rsigma_cu.data(), z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount, at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(), workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data(), fp8_out); scale_inv_cu.data());
} }
......
...@@ -165,8 +165,7 @@ void dispatch_layernorm(void* input, // i ...@@ -165,8 +165,7 @@ void dispatch_layernorm(void* input, // i
void* scale_inv, // o void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape, const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type, const transformer_engine::DType scale_inv_type,
const int multiProcessorCount, const int multiProcessorCount
const bool fp8_out
); );
......
...@@ -307,8 +307,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, ...@@ -307,8 +307,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
rsigma.data_ptr(), {N}, DType::kFloat32, rsigma.data_ptr(), {N}, DType::kFloat32,
amax.data_ptr(), {1}, DType::kFloat32, amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32, scale_inv.data_ptr(), {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount, at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
true);
return {ln_out, mu, rsigma}; return {ln_out, mu, rsigma};
} }
...@@ -340,8 +339,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, ...@@ -340,8 +339,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
rsigma.data_ptr(), {N}, DType::kFloat32, rsigma.data_ptr(), {N}, DType::kFloat32,
nullptr, {1}, DType::kFloat32, nullptr, {1}, DType::kFloat32,
nullptr, {1}, DType::kFloat32, nullptr, {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount, at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
false);
return {ln_out, mu, rsigma}; return {ln_out, mu, rsigma};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment