Unverified Commit 02a3582c authored by Kirthi Shankar Sivamani, committed by GitHub

Move calculation of scale inverse to framework (#51)



* Move scale inverse calculation to framework
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix RMSNorm
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix gated kernel/geglu
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 40467fc2
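The diff below touches the C++ tests, the CUDA kernels, and the Python framework. The gist, as a minimal PyTorch sketch (the names and the simplified scale formula here are illustrative, not the actual TransformerEngine API): kernels used to write `scale_inv = 1/scale` as a side effect of every FP8 cast, and after this change the framework produces `scale_inv` once, next to the amax-to-scale update.

```python
import torch

# Illustrative delayed-scaling update (simplified; not the real recipe code).
fp8_max = 448.0                      # max representable magnitude in E4M3
margin = 0
amax = torch.tensor([100.0, 3.0])    # per-tensor abs-max from recent steps

# scale ~ 2^(floor(log2(fp8_max / amax)) - margin)
exp = torch.floor(torch.log2(fp8_max / amax)) - margin
scale = torch.pow(2.0, exp)

# Before this commit: one thread in each kernel wrote reciprocal(scale_inv, scale).
# After: the framework computes it once per update, alongside `scale`.
scale_inv = torch.reciprocal(scale)
```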
@@ -115,8 +115,6 @@ void performTest(const size_t N, const size_t H) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output_c.scale();
-    compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
   compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
...
@@ -131,8 +131,6 @@ void performTest(const size_t N, const size_t H) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output_c.scale();
-    compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
...
@@ -70,8 +70,6 @@ void performTestGEGLU(const size_t N, const size_t H) {
   if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output.scale();
-    compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
   compareResults("output_gelu", output, ref_output.get(), atol, rtol);
...
@@ -70,8 +70,6 @@ void performTestGelu(const size_t N, const size_t H) {
   if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output.scale();
-    compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
   compareResults("output_gelu", output, ref_output.get(), atol, rtol);
...
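The test hunks above drop the per-kernel `scale_inv` check, since the kernels no longer produce that value. The equivalent property now lives framework-side; a hypothetical check of the new masked-update semantics (mirroring `_compute_scaling_factor_inverse` later in this diff) might look like:

```python
import torch

def test_scale_inv_update_respects_weight_mask():
    # With update_weight_scale_inv=False, only non-weight positions pick up
    # the fresh reciprocal; weight slots keep their previous value.
    scale = torch.tensor([2.0, 4.0, 8.0])
    old_scale_inv = torch.ones(3)
    non_weight_mask = torch.tensor([True, False, True])

    new_scale_inv = torch.where(non_weight_mask, 1.0 / scale, old_scale_inv)

    torch.testing.assert_close(new_scale_inv, torch.tensor([0.5, 1.0, 0.125]))

test_scale_inv_update_respects_weight_mask()
```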
@@ -215,8 +215,6 @@ void performTest(const size_t N, const size_t H) {
   auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
   if (isFp8Type(otype)) {
     compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / z.scale();
-    compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
...
@@ -27,7 +27,6 @@ void compute_ref(const std::vector<std::vector<InputType>>& input_list,
                  std::vector<std::vector<OutputType>>& output_t_list,
                  const std::vector<float>& scale_list,
                  std::vector<float>& amax_list,
-                 std::vector<float>& scale_inv_list,
                  const std::vector<size_t>& height_list,
                  const std::vector<size_t>& width_list) {
   using compute_t = float;
@@ -37,10 +36,8 @@ void compute_ref(const std::vector<std::vector<InputType>>& input_list,
     auto& output_t = output_t_list[tensor_id];
     const compute_t scale = scale_list[tensor_id];
     compute_t& amax = amax_list[tensor_id];
-    compute_t& scale_inv = scale_inv_list[tensor_id];
     const size_t height = height_list[tensor_id];
     const size_t width = width_list[tensor_id];
-    scale_inv = 1. / scale;
     amax = -1e100;
     for (size_t i = 0; i < height; ++i) {
       for (size_t j = 0; j < width; ++j) {
@@ -76,8 +73,7 @@ void performTest() {
   // Buffers for reference implementation
   std::vector<std::vector<InputType>> ref_input_list;
   std::vector<std::vector<OutputType>> ref_output_c_list, ref_output_t_list;
-  std::vector<float> ref_scale_list(num_tensors), ref_amax_list(num_tensors),
-                     ref_scale_inv_list(num_tensors);
+  std::vector<float> ref_scale_list(num_tensors), ref_amax_list(num_tensors);
   std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);

   // Initialize buffers
@@ -128,7 +124,6 @@ void performTest() {
               ref_output_t_list,
               ref_scale_list,
               ref_amax_list,
-              ref_scale_inv_list,
               ref_height_list,
               ref_width_list);
@@ -143,10 +138,6 @@ void performTest() {
                      output_c_list[tensor_id].amax(),
                      ref_amax_list[tensor_id],
                      atol_amax, rtol_amax);
-      compareResults("scale_inv",
-                     output_c_list[tensor_id].scale_inv(),
-                     ref_scale_inv_list[tensor_id],
-                     atol_amax, rtol_amax);
     }
     auto [atol, rtol] = getTolerances(otype);
     compareResults("output_c",
...
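For readers following the multi-tensor test: the reference computation now returns only the cast/transposed outputs and the per-tensor amaxes. A rough Python analogue of the slimmed-down `compute_ref` (the FP8 cast is emulated with a scale-and-clamp, and `fp8_max` is assumed E4M3):

```python
import torch

def compute_ref(input_list, scale_list, fp8_max=448.0):
    output_c_list, output_t_list, amax_list = [], [], []
    for x, scale in zip(input_list, scale_list):
        amax_list.append(x.abs().max())           # what the kernel reduces
        q = (x * scale).clamp(-fp8_max, fp8_max)  # stand-in for the FP8 cast
        output_c_list.append(q)
        output_t_list.append(q.t().contiguous())
    return output_c_list, output_t_list, amax_list
```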
@@ -78,8 +78,6 @@ void performTestQ(const size_t N) {
   auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
   compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-  float ref_scale_inv = 1.f / output.scale();
-  compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);

   auto [atol, rtol] = getTolerances(otype);
   compareResults("output_q", output, ref_output.get(), atol, rtol);
 }
...
@@ -172,8 +172,6 @@ void performTest(const size_t N, const size_t H) {
   auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
   if (isFp8Type(otype)) {
     compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / z.scale();
-    compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
...
@@ -43,7 +43,6 @@ void gelu_cast(const Tensor &input,
     reinterpret_cast<const IType*>(input.data.dptr),
     reinterpret_cast<OType*>(output->data.dptr),
     reinterpret_cast<const fp32*>(output->scale.dptr),
-    reinterpret_cast<fp32*>(output->scale_inv.dptr),
     reinterpret_cast<fp32*>(output->amax.dptr),
     tot_elts,
     {},
@@ -71,7 +70,6 @@ void geglu_cast(const Tensor &input,
     reinterpret_cast<const IType*>(input.data.dptr),
     reinterpret_cast<OType*>(output->data.dptr),
     reinterpret_cast<const fp32*>(output->scale.dptr),
-    reinterpret_cast<fp32*>(output->scale_inv.dptr),
     reinterpret_cast<fp32*>(output->amax.dptr),
     output->data.shape[0],
     output->data.shape[1],
...
@@ -87,10 +87,6 @@ struct FwdParams : public ParamsBase {
   // Scaling factor
   void *scale;

-  // Scaling factor inverse,
-  // needed for cublasLt fp8 gemm
-  void *scale_inv;
-
   // AMax output
   void *amax;
...
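The removed comment is worth keeping in mind: `scale_inv` is still consumed downstream (the comment cites the cuBLASLt FP8 GEMM) to undo the quantization scaling; only its producer changes. A float-emulated sketch of that role, in plain tensors since this is about the arithmetic rather than the FP8 storage:

```python
import torch

x = torch.randn(4, 4)
scale = torch.tensor(64.0)   # power of two, so the round trip is exact

x_fp8 = x * scale            # what gets stored (truly quantized in FP8)
scale_inv = 1.0 / scale      # now framework-owned rather than kernel-written
x_back = x_fp8 * scale_inv   # what the GEMM effectively multiplies by

torch.testing.assert_close(x_back, x)
```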
@@ -207,7 +207,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
   params.epsilon = epsilon;
   params.amax = z->amax.dptr;
   params.scale = z->scale.dptr;
-  params.scale_inv = z->scale_inv.dptr;
   params.fp8_out = fp8_out;

   // Query the kernel-specific launch parameters.
...
@@ -133,7 +133,6 @@ void ln_fwd_tuned_kernel(FwdParams params) {
     if (threadIdx.x == 0 && threadIdx.y == 0) {
       static_assert(std::is_same<compute_t, float>::value);
       atomicMaxFloat(reinterpret_cast<compute_t*>(params.amax), amax);
-      reciprocal<compute_t>(reinterpret_cast<compute_t*>(params.scale_inv), scale);
     }
   }
 }
@@ -302,7 +301,6 @@ void ln_fwd_general_kernel(FwdParams params) {
     if ( threadIdx.x == 0 ) {
       static_assert(std::is_same<compute_t, float>::value);
       atomicMaxFloat(reinterpret_cast<compute_t*>(params.amax), amax);
-      reciprocal<compute_t>(reinterpret_cast<compute_t*>(params.scale_inv), scale);
     }
   }
 }
...
@@ -154,7 +154,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
   params.epsilon = epsilon;
   params.amax = z->amax.dptr;
   params.scale = z->scale.dptr;
-  params.scale_inv = z->scale_inv.dptr;
   params.fp8_out = fp8_out;

   // Query the kernel-specific launch parameters.
...
@@ -125,7 +125,6 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
     if (threadIdx.x == 0 && threadIdx.y == 0) {
       static_assert(std::is_same<compute_t, float>::value);
       atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
-      reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
     }
   }
 }
@@ -266,7 +265,6 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
     if (threadIdx.x == 0) {
       static_assert(std::is_same<compute_t, float>::value);
       atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
-      reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
     }
   }
 }
...
@@ -62,7 +62,6 @@ cast_transpose_kernel(const IType * const input,
                       OType * const output_t,
                       const CType * const scale_ptr,
                       CType * const amax,
-                      CType * const scale_inv,
                       const size_t row_length,
                       const size_t num_rows,
                       const size_t num_tiles) {
@@ -159,7 +158,6 @@ cast_transpose_kernel(const IType * const input,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
     if (amax != nullptr) atomicMaxFloat(amax, max);
-    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
   }
 }
@@ -171,7 +169,6 @@ cast_transpose_kernel_notaligned(const IType * const input,
                                  OType * const output_t,
                                  const CType * const scale_ptr,
                                  CType * const amax,
-                                 CType * const scale_inv,
                                  const size_t row_length,
                                  const size_t num_rows,
                                  const size_t num_tiles) {
@@ -295,7 +292,6 @@ cast_transpose_kernel_notaligned(const IType * const input,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
     if (amax != nullptr) atomicMaxFloat(amax, max);
-    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
   }
 }
@@ -324,8 +320,6 @@ void cast_transpose(const Tensor &input,
              "C and T outputs need to share amax tensor.");
   NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
              "C and T outputs need to share scale tensor.");
-  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
-             "C and T outputs need to share scale inverse tensor.");
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
@@ -361,7 +355,6 @@ void cast_transpose(const Tensor &input,
         reinterpret_cast<OutputType *>(transposed_output->data.dptr),
         reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
         reinterpret_cast<fp32 *>(cast_output->amax.dptr),
-        reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
         row_length, num_rows, n_tiles);
     } else {
       cudaFuncSetAttribute(cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
@@ -379,7 +372,6 @@ void cast_transpose(const Tensor &input,
         reinterpret_cast<OutputType *>(transposed_output->data.dptr),
         reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
         reinterpret_cast<fp32 *>(cast_output->amax.dptr),
-        reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
         row_length, num_rows, n_tiles);
     }
   );  // NOLINT(*)
...
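To keep the kernel's contract clear after this change, here is a hypothetical reference for what `cast_transpose` still produces: a quantized row-major copy, a quantized transpose, and a shared amax. The `scale_inv` write, and the check that both outputs share a `scale_inv` pointer, are simply gone.

```python
import torch

def ref_cast_transpose(x, scale, fp8_max=448.0):
    amax = x.abs().max()
    q = (x * scale).clamp(-fp8_max, fp8_max)  # stand-in for the FP8 cast
    return q, q.t().contiguous(), amax
```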
@@ -114,7 +114,6 @@ struct CTDBiasParam {
   OType *output_t;
   const CType *scale_ptr;
   CType *amax;
-  CType *scale_inv;
   CType *workspace;
 };
@@ -130,7 +129,6 @@ struct CTDBiasDGeluParam {
   OType *output_t;
   const CType *scale_ptr;
   CType *amax;
-  CType *scale_inv;
   CType *workspace;
 };
@@ -273,7 +271,6 @@ cast_transpose_dbias_kernel(const Param param,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
     if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
-    if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
   }
 }
@@ -442,7 +439,6 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
     if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
-    if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
   }
 }
@@ -555,8 +551,6 @@ void cast_transpose_dbias(const Tensor &input,
              "C and T outputs need to share amax tensor.");
   NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
              "C and T outputs need to share scale tensor.");
-  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
-             "C and T outputs need to share scale inverse tensor.");
   NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
   NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
@@ -597,7 +591,6 @@ void cast_transpose_dbias(const Tensor &input,
       param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
       param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
       param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
-      param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
       param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);

       if (full_tile) {
@@ -782,7 +775,6 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
     if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
-    if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
   }
 }
@@ -973,7 +965,6 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
     if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
-    if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
   }
 }
@@ -1382,8 +1373,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
              "C and T outputs need to share amax tensor.");
   NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
              "C and T outputs need to share scale tensor.");
-  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
-             "C and T outputs need to share scale inverse tensor.");
   NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
   NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
@@ -1432,7 +1421,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
       param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
       param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
       param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
-      param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
       param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);

       if (full_tile) {
         cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>,
...
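Same story for the fused backward kernels: they still compute dGELU, the bias gradient, the cast, and the transpose in one pass, and only the `scale_inv` side effect disappears. A hedged reference using PyTorch autograd (the tanh GELU approximation and the dbias-over-dgelu reduction are assumptions about the fused semantics, used here purely for illustration):

```python
import torch
import torch.nn.functional as F

def ref_cast_transpose_dbias_dgelu(grad_out, gelu_in, scale, fp8_max=448.0):
    gelu_in = gelu_in.detach().requires_grad_(True)
    F.gelu(gelu_in, approximate="tanh").backward(grad_out)
    dgelu = gelu_in.grad                          # backward of GELU
    dbias = dgelu.sum(dim=0)                      # column-wise bias gradient
    amax = dgelu.abs().max()
    q = (dgelu * scale).clamp(-fp8_max, fp8_max)  # stand-in for the FP8 cast
    return q, q.t().contiguous(), dbias, amax
```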
@@ -34,8 +34,6 @@ struct MultiCastTransposeArgs {
   void* scale_list[kMaxTensorsPerKernel];
   // (output) AMAX's of input tensors
   void* amax_list[kMaxTensorsPerKernel];
-  // (output) Reciprocal of scaling factors
-  void* scale_inv_list[kMaxTensorsPerKernel];
   // Input matrix heights
   int num_rows_list[kMaxTensorsPerKernel];
   // Input matrix widths
@@ -90,7 +88,6 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
   const CType* scale_ptr = reinterpret_cast<CType*>(args.scale_list[tensor_id]);
   const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
   CType* amax = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
-  CType* scale_inv = reinterpret_cast<CType*>(args.scale_inv_list[tensor_id]);
   const int num_rows = args.num_rows_list[tensor_id];
   const int row_length = args.row_length_list[tensor_id];
@@ -193,9 +190,6 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
     static_assert(std::is_same<CType, float>::value);
     if (amax != nullptr) atomicMaxFloat(amax, local_amax);
   }
-  if (tid == 0 && tile_id == 0) {
-    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
-  }
 }

 }  // namespace
@@ -300,7 +294,6 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
     kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
     kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
     kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
-    kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
     kernel_args.num_rows_list[pos] = num_rows;
     kernel_args.row_length_list[pos] = row_length;
     kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles;
...
@@ -50,7 +50,6 @@ void fp8_quantize(const Tensor &input,
     reinterpret_cast<const IType*>(input.data.dptr),
     reinterpret_cast<OType*>(output->data.dptr),
     reinterpret_cast<const fp32*>(output->scale.dptr),
-    reinterpret_cast<fp32*>(output->scale_inv.dptr),
     reinterpret_cast<fp32*>(output->amax.dptr),
     N,
     {},
@@ -82,7 +81,6 @@ void fp8_dequantize(const Tensor &input,
     reinterpret_cast<OType*>(output->data.dptr),
     nullptr,
     nullptr,
-    nullptr,
     N,
     p,
     stream);
...
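`fp8_quantize` now only writes the quantized data and the amax (note the dropped pointer argument in both call sites; `fp8_dequantize` had already been passing `nullptr` for the unused outputs). In sketch form, with the FP8 cast again emulated by a clamp:

```python
import torch

def fp8_quantize(x, scale, fp8_max=448.0):
    # Scale, saturate, record amax; no scale_inv side effect anymore.
    amax = x.abs().max()
    q = (x * scale).clamp(-fp8_max, fp8_max)
    return q, amax

def fp8_dequantize(q, scale_inv):
    # A plain multiply by the framework-owned reciprocal.
    return q * scale_inv
```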
@@ -185,7 +185,6 @@ __launch_bounds__(unary_kernel_threads)
 __global__ void unary_kernel(const InputType *input,
                              OutputType *output,
                              const ComputeType *scale,
-                             ComputeType *scale_inv,
                              ComputeType *amax,
                              Param p,
                              const size_t N,
@@ -196,9 +195,6 @@ __global__ void unary_kernel(const InputType *input,
   ComputeType s = 0;
   if constexpr (is_fp8<OutputType>::value) {
     if (scale != nullptr) s = *scale;
-    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
-      reciprocal<ComputeType>(scale_inv, s);
-    }
   }
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -295,7 +291,6 @@ template <int nvec, typename Param,
 void VectorizedUnaryKernelLauncher(const InputType *input,
                                    OutputType *output,
                                    const fp32 *scale,
-                                   fp32 *scale_inv,
                                    fp32 *amax,
                                    const size_t N,
                                    const Param params,
@@ -313,16 +308,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input,
     switch (align) {
       case Alignment::SAME_ALIGNED:
         unary_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, params, N, num_aligned_elements);
+          input, output, scale, amax, params, N, num_aligned_elements);
         break;
       case Alignment::SAME_UNALIGNED:
         unary_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, params, N, num_aligned_elements);
+          input, output, scale, amax, params, N, num_aligned_elements);
         break;
       case Alignment::DIFFERENT: {
         // If the pointers are aligned differently we cannot vectorize
         unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, params, N, N);
+          input, output, scale, amax, params, N, N);
         break;
       }
     }
@@ -338,7 +333,6 @@ __launch_bounds__(unary_kernel_threads)
 __global__ void gated_act_kernel(const InputType *input,
                                  OutputType *output,
                                  const ComputeType *scale,
-                                 ComputeType *scale_inv,
                                  ComputeType *amax,
                                  const size_t m,
                                  const size_t n,
@@ -356,9 +350,6 @@ __global__ void gated_act_kernel(const InputType *input,
   ComputeType s = 0;
   if constexpr (is_fp8<OutputType>::value) {
     if (scale != nullptr) s = *scale;
-    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
-      reciprocal<ComputeType>(scale_inv, s);
-    }
   }
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -398,7 +389,6 @@ template <int nvec,
 void GatedActivationKernelLauncher(const InputType *input,
                                    OutputType *output,
                                    const fp32 *scale,
-                                   fp32 *scale_inv,
                                    fp32 *amax,
                                    const size_t m,
                                    const size_t n,
@@ -413,16 +403,16 @@ void GatedActivationKernelLauncher(const InputType *input,
     switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
       case Alignment::SAME_ALIGNED:
         gated_act_kernel<nvec, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
+          input, output, scale, amax, m, n, num_aligned_elements);
         break;
       case Alignment::SAME_UNALIGNED:
         gated_act_kernel<nvec, false, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
+          input, output, scale, amax, m, n, num_aligned_elements);
         break;
       case Alignment::DIFFERENT: {
         // If the pointers are aligned differently we cannot vectorize
         gated_act_kernel<1, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, m, n, n);
+          input, output, scale, amax, m, n, n);
         break;
       }
     }
...
@@ -341,14 +341,30 @@ def _default_sf_compute(
     return sf


+@torch.jit.script
+def _compute_scaling_factor_inverse(
+    scale: torch.Tensor,
+    scale_inv: torch.Tensor,
+    non_weight_mask: torch.Tensor,
+    update_weight_scale_inv: bool,
+) -> torch.Tensor:
+    """Compute inverse of scaling factor."""
+    if update_weight_scale_inv:
+        return 1.0 / scale
+    return torch.where(non_weight_mask, 1.0 / scale, scale_inv)
+
+
 @torch.jit.script
 def fused_amax_and_scale_update(
     amax_history: torch.Tensor,
     scale: torch.Tensor,
+    scale_inv: torch.Tensor,
     fp8_max: float,
     margin: int,
     amax_compute_algo: str,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    non_weight_mask: torch.Tensor,
+    update_weight_scale_inv: bool,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Amax to scale conversion."""

     # Get amax from history.
@@ -358,13 +374,23 @@ def fused_amax_and_scale_update(
     )

     # Calculate new scaling factor.
-    return amax_history, _default_sf_compute(
+    scale = _default_sf_compute(
         amax,
         scale,
         fp8_max,
         margin,
     )

+    # Calculate new inverse of scaling factor.
+    scale_inv = _compute_scaling_factor_inverse(
+        scale,
+        scale_inv,
+        non_weight_mask,
+        update_weight_scale_inv,
+    )
+
+    return amax_history, scale, scale_inv
+

 def _compute_amax(
     amax_history: torch.Tensor,
@@ -403,6 +429,7 @@ def _compute_scaling_factor(
 def amax_and_scale_update(
     fp8_meta: Dict[str, Any],
     fwd_update: bool,
+    update_weight_scale_inv: bool = True,
 ) -> None:
     """Updates fp8 amaxes/scales for fwd | bwd."""
     amax_compute = fp8_meta["recipe"].amax_compute_algo
@@ -414,12 +441,16 @@ def amax_and_scale_update(
         (
             fp8_meta[fp8_meta_tensor_key].amax_history,
             fp8_meta[fp8_meta_tensor_key].scale,
+            fp8_meta[fp8_meta_tensor_key].scale_inv,
         ) = fused_amax_and_scale_update(
             fp8_meta[fp8_meta_tensor_key].amax_history,
             fp8_meta[fp8_meta_tensor_key].scale,
+            fp8_meta[fp8_meta_tensor_key].scale_inv,
             fp8_meta[fp8_max_key],
             fp8_meta["recipe"].margin,
             fp8_meta["recipe"].amax_compute_algo,
+            fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
+            update_weight_scale_inv,
         )
     else:
         fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax(
@@ -432,6 +463,12 @@ def amax_and_scale_update(
             fp8_meta[fp8_max_key],
             fp8_meta["recipe"],
         )
+        fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse(
+            fp8_meta[fp8_meta_tensor_key].scale,
+            fp8_meta[fp8_meta_tensor_key].scale_inv,
+            fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
+            update_weight_scale_inv,
+        )


 def get_fp8_te_dtype(
...
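Putting the framework pieces together: per update, the fused path now returns `(amax_history, scale, scale_inv)`, and the unfused path calls `_compute_scaling_factor_inverse` explicitly. A self-contained rehearsal with simplified stand-ins (plain tensors instead of the `fp8_meta` structures, and a toy scale formula in place of `_default_sf_compute`):

```python
import torch

def sf_compute(amax, scale, fp8_max, margin):
    # Simplified stand-in for _default_sf_compute.
    exp = torch.floor(torch.log2(fp8_max / amax)) - margin
    sf = torch.pow(2.0, exp)
    return torch.where(torch.isfinite(sf) & (amax > 0), sf, scale)

amax_history = torch.tensor([[100.0, 3.0], [80.0, 2.5]])
scale = torch.ones(2)
scale_inv = torch.ones(2)
non_weight_mask = torch.tensor([True, False])  # slot 1 holds a weight

amax = amax_history[0]                         # "most_recent" reduction
scale = sf_compute(amax, scale, fp8_max=448.0, margin=0)

update_weight_scale_inv = False                # weights not re-quantized this step
scale_inv = torch.where(non_weight_mask, 1.0 / scale, scale_inv)

print(scale, scale_inv)  # the weight slot keeps its previous scale_inv
```

The `non_weight_mask` / `update_weight_scale_inv` pair presumably exists so that a cached FP8 weight and the `scale_inv` used to dequantize it stay consistent between re-quantizations; activations and gradients are re-cast every step, so their reciprocals can always be refreshed.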