Unverified commit 6a2161bf, authored by Przemyslaw Tredak, committed by GitHub
Browse files

Remove fp8_out from the LN API (#8)



* Remove fp8_out from LN API
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* fix LN test
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: ksivamani <ksivamani@nvidia.com>
parent 62f93325
......@@ -163,12 +163,12 @@ void performTest(const size_t N, const size_t H) {
float epsilon = 1e-5;
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data(), true);
workspace.data(), barrier.data(), amax.data(), scale_inv.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data(), true);
workspace.data(), barrier.data(), amax.data(), scale_inv.data());
// Backward kernel
nvte_layernorm_bwd(dz.data(), input.data(),
......@@ -195,6 +195,7 @@ void performTest(const size_t N, const size_t H) {
float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? *(scale.cpu_dptr<float>()) : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(),
gamma.cpu_dptr<WeightType>(),
beta.cpu_dptr<WeightType>(),
......@@ -203,7 +204,7 @@ void performTest(const size_t N, const size_t H) {
rsigma.cpu_dptr<float>(),
N, H,
&ref_amax,
*(scale.cpu_dptr<float>()));
ref_scale);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
gamma.cpu_dptr<WeightType>(),
......@@ -215,9 +216,11 @@ void performTest(const size_t N, const size_t H) {
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5;
......
......@@ -171,4 +171,8 @@ void fillUniform(const Tensor &t) {
t.from_cpu();
}
// Returns true iff `type` is one of the two FP8 formats handled by the
// test harness (E4M3 or E5M2); all other dtypes are considered non-FP8.
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}
} // namespace test
......@@ -156,6 +156,8 @@ const std::string &typeName(DType type);
extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);
} // namespace test
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
......
......@@ -39,7 +39,6 @@ extern "C" {
* \param[out] barrier Barrier tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] fp8_out Whether to output FP8.
*/
void nvte_layernorm_fwd(const NVTETensor x,
const NVTETensor gamma,
......@@ -54,8 +53,7 @@ void nvte_layernorm_fwd(const NVTETensor x,
NVTETensor workspace,
NVTETensor barrier,
NVTETensor amax,
NVTETensor scale_inv,
bool fp8_out);
NVTETensor scale_inv);
/*! \brief Compute backward of LayerNorm.
......
......@@ -154,12 +154,13 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
Tensor* workspace,
Tensor* barrier,
Tensor* amax,
Tensor *scale_inv,
bool fp8_out
Tensor *scale_inv
) {
auto itype = x.dtype;
auto wtype = gamma.dtype;
auto otype = z->dtype;
bool fp8_out = otype == DType::kFloat8E4M3 ||
otype == DType::kFloat8E5M2;
auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.shape.size() == 2);
......@@ -382,8 +383,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
NVTETensor workspace,
NVTETensor barrier,
NVTETensor amax,
NVTETensor scale_inv,
bool fp8_out) {
NVTETensor scale_inv) {
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
......@@ -398,8 +398,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(scale_inv),
fp8_out);
reinterpret_cast<Tensor*>(scale_inv));
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
......
......@@ -69,7 +69,7 @@ void ln_fwd_tuned_kernel(FwdParams params) {
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
compute_t scale;
compute_t scale = 1.f;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t*>(params.scale);
}
......
......@@ -119,8 +119,7 @@ void dispatch_layernorm(void* input, // i
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount,
const bool fp8_out
const int multiProcessorCount
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
......@@ -139,7 +138,7 @@ void dispatch_layernorm(void* input, // i
z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data(), fp8_out);
scale_inv_cu.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
......@@ -160,7 +159,7 @@ void dispatch_layernorm(void* input, // i
z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data(), fp8_out);
scale_inv_cu.data());
}
......
......@@ -165,8 +165,7 @@ void dispatch_layernorm(void* input, // i
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount,
const bool fp8_out
const int multiProcessorCount
);
......
......@@ -307,8 +307,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
rsigma.data_ptr(), {N}, DType::kFloat32,
amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
true);
at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
return {ln_out, mu, rsigma};
}
......@@ -340,8 +339,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
rsigma.data_ptr(), {N}, DType::kFloat32,
nullptr, {1}, DType::kFloat32,
nullptr, {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
false);
at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
return {ln_out, mu, rsigma};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment