"tests/vscode:/vscode.git/clone" did not exist on "e3cec88aa5b7ac391e4aa6dc9b6388100d59d8f9"
Unverified Commit 02a3582c authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Move calculation of scale inverse to framework (#51)



* Move scale inverse calculation to framework
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix RMSNorm
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix gated kernel/geglu
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 40467fc2
......@@ -115,8 +115,6 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
......
......@@ -131,8 +131,6 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
......
......@@ -70,8 +70,6 @@ void performTestGEGLU(const size_t N, const size_t H) {
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
......
......@@ -70,8 +70,6 @@ void performTestGelu(const size_t N, const size_t H) {
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
......
......@@ -215,8 +215,6 @@ void performTest(const size_t N, const size_t H) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
......@@ -27,7 +27,6 @@ void compute_ref(const std::vector<std::vector<InputType>>& input_list,
std::vector<std::vector<OutputType>>& output_t_list,
const std::vector<float>& scale_list,
std::vector<float>& amax_list,
std::vector<float>& scale_inv_list,
const std::vector<size_t>& height_list,
const std::vector<size_t>& width_list) {
using compute_t = float;
......@@ -37,10 +36,8 @@ void compute_ref(const std::vector<std::vector<InputType>>& input_list,
auto& output_t = output_t_list[tensor_id];
const compute_t scale = scale_list[tensor_id];
compute_t& amax = amax_list[tensor_id];
compute_t& scale_inv = scale_inv_list[tensor_id];
const size_t height = height_list[tensor_id];
const size_t width = width_list[tensor_id];
scale_inv = 1. / scale;
amax = -1e100;
for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < width; ++j) {
......@@ -76,8 +73,7 @@ void performTest() {
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
std::vector<std::vector<OutputType>> ref_output_c_list, ref_output_t_list;
std::vector<float> ref_scale_list(num_tensors), ref_amax_list(num_tensors),
ref_scale_inv_list(num_tensors);
std::vector<float> ref_scale_list(num_tensors), ref_amax_list(num_tensors);
std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
// Initialize buffers
......@@ -128,7 +124,6 @@ void performTest() {
ref_output_t_list,
ref_scale_list,
ref_amax_list,
ref_scale_inv_list,
ref_height_list,
ref_width_list);
......@@ -143,10 +138,6 @@ void performTest() {
output_c_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
compareResults("scale_inv",
output_c_list[tensor_id].scale_inv(),
ref_scale_inv_list[tensor_id],
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
......
......@@ -78,8 +78,6 @@ void performTestQ(const size_t N) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_q", output, ref_output.get(), atol, rtol);
}
......
......@@ -172,8 +172,6 @@ void performTest(const size_t N, const size_t H) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
......@@ -43,7 +43,6 @@ void gelu_cast(const Tensor &input,
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->scale_inv.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
{},
......@@ -71,7 +70,6 @@ void geglu_cast(const Tensor &input,
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->scale_inv.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
output->data.shape[0],
output->data.shape[1],
......
......@@ -87,10 +87,6 @@ struct FwdParams : public ParamsBase {
// Scaling factor
void *scale;
// Scaling factor inverse,
// needed for cublasLt fp8 gemm
void *scale_inv;
// AMax output
void *amax;
......
......@@ -207,7 +207,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
// Query the kernel-specific launch parameters.
......
......@@ -133,7 +133,6 @@ void ln_fwd_tuned_kernel(FwdParams params) {
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t*>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t*>(params.scale_inv), scale);
}
}
}
......@@ -302,7 +301,6 @@ void ln_fwd_general_kernel(FwdParams params) {
if ( threadIdx.x == 0 ) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t*>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t*>(params.scale_inv), scale);
}
}
}
......
......@@ -154,7 +154,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
// Query the kernel-specific launch parameters.
......
......@@ -125,7 +125,6 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
......@@ -266,7 +265,6 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
......
......@@ -62,7 +62,6 @@ cast_transpose_kernel(const IType * const input,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
CType * const scale_inv,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
......@@ -159,7 +158,6 @@ cast_transpose_kernel(const IType * const input,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, max);
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
......@@ -171,7 +169,6 @@ cast_transpose_kernel_notaligned(const IType * const input,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
CType * const scale_inv,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
......@@ -295,7 +292,6 @@ cast_transpose_kernel_notaligned(const IType * const input,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, max);
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
......@@ -324,8 +320,6 @@ void cast_transpose(const Tensor &input,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
......@@ -361,7 +355,6 @@ void cast_transpose(const Tensor &input,
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
......@@ -379,7 +372,6 @@ void cast_transpose(const Tensor &input,
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
}
); // NOLINT(*)
......
......@@ -114,7 +114,6 @@ struct CTDBiasParam {
OType *output_t;
const CType *scale_ptr;
CType *amax;
CType *scale_inv;
CType *workspace;
};
......@@ -130,7 +129,6 @@ struct CTDBiasDGeluParam {
OType *output_t;
const CType *scale_ptr;
CType *amax;
CType *scale_inv;
CType *workspace;
};
......@@ -273,7 +271,6 @@ cast_transpose_dbias_kernel(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -442,7 +439,6 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -555,8 +551,6 @@ void cast_transpose_dbias(const Tensor &input,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
......@@ -597,7 +591,6 @@ void cast_transpose_dbias(const Tensor &input,
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
......@@ -782,7 +775,6 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -973,7 +965,6 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -1382,8 +1373,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
......@@ -1432,7 +1421,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>,
......
......@@ -34,8 +34,6 @@ struct MultiCastTransposeArgs {
void* scale_list[kMaxTensorsPerKernel];
// (output) AMAX's of input tensors
void* amax_list[kMaxTensorsPerKernel];
// (output) Reciprocal of scaling factors
void* scale_inv_list[kMaxTensorsPerKernel];
// Input matrix heights
int num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths
......@@ -90,7 +88,6 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
const CType* scale_ptr = reinterpret_cast<CType*>(args.scale_list[tensor_id]);
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
CType* amax = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
CType* scale_inv = reinterpret_cast<CType*>(args.scale_inv_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
......@@ -193,9 +190,6 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, local_amax);
}
if (tid == 0 && tile_id == 0) {
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
} // namespace
......@@ -300,7 +294,6 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles;
......
......@@ -50,7 +50,6 @@ void fp8_quantize(const Tensor &input,
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->scale_inv.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
N,
{},
......@@ -82,7 +81,6 @@ void fp8_dequantize(const Tensor &input,
reinterpret_cast<OType*>(output->data.dptr),
nullptr,
nullptr,
nullptr,
N,
p,
stream);
......
......@@ -185,7 +185,6 @@ __launch_bounds__(unary_kernel_threads)
__global__ void unary_kernel(const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *scale_inv,
ComputeType *amax,
Param p,
const size_t N,
......@@ -196,9 +195,6 @@ __global__ void unary_kernel(const InputType *input,
ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
......@@ -295,7 +291,6 @@ template <int nvec, typename Param,
void VectorizedUnaryKernelLauncher(const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *scale_inv,
fp32 *amax,
const size_t N,
const Param params,
......@@ -313,16 +308,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input,
switch (align) {
case Alignment::SAME_ALIGNED:
unary_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, params, N, num_aligned_elements);
input, output, scale, amax, params, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
unary_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, params, N, num_aligned_elements);
input, output, scale, amax, params, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, params, N, N);
input, output, scale, amax, params, N, N);
break;
}
}
......@@ -338,7 +333,6 @@ __launch_bounds__(unary_kernel_threads)
__global__ void gated_act_kernel(const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *scale_inv,
ComputeType *amax,
const size_t m,
const size_t n,
......@@ -356,9 +350,6 @@ __global__ void gated_act_kernel(const InputType *input,
ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
......@@ -398,7 +389,6 @@ template <int nvec,
void GatedActivationKernelLauncher(const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *scale_inv,
fp32 *amax,
const size_t m,
const size_t n,
......@@ -413,16 +403,16 @@ void GatedActivationKernelLauncher(const InputType *input,
switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
case Alignment::SAME_ALIGNED:
gated_act_kernel<nvec, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
input, output, scale, amax, m, n, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
gated_act_kernel<nvec, false, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
input, output, scale, amax, m, n, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
gated_act_kernel<1, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, m, n, n);
input, output, scale, amax, m, n, n);
break;
}
}
......
......@@ -341,14 +341,30 @@ def _default_sf_compute(
return sf
@torch.jit.script
def _compute_scaling_factor_inverse(
scale: torch.Tensor,
scale_inv: torch.Tensor,
non_weight_mask: torch.Tensor,
update_weight_scale_inv: bool,
) -> torch.Tensor:
"""Compute inverse of scaling factor."""
if update_weight_scale_inv:
return 1.0 / scale
return torch.where(non_weight_mask, 1.0 / scale, scale_inv)
@torch.jit.script
def fused_amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
scale_inv: torch.Tensor,
fp8_max: float,
margin: int,
amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
non_weight_mask: torch.Tensor,
update_weight_scale_inv: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Amax to scale conversion."""
# Get amax from history.
......@@ -358,13 +374,23 @@ def fused_amax_and_scale_update(
)
# Calculate new scaling factor.
return amax_history, _default_sf_compute(
scale = _default_sf_compute(
amax,
scale,
fp8_max,
margin,
)
# Calculate new inverse of scaling factor.
scale_inv = _compute_scaling_factor_inverse(
scale,
scale_inv,
non_weight_mask,
update_weight_scale_inv,
)
return amax_history, scale, scale_inv
def _compute_amax(
amax_history: torch.Tensor,
......@@ -403,6 +429,7 @@ def _compute_scaling_factor(
def amax_and_scale_update(
fp8_meta: Dict[str, Any],
fwd_update: bool,
update_weight_scale_inv: bool = True,
) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd."""
amax_compute = fp8_meta["recipe"].amax_compute_algo
......@@ -414,12 +441,16 @@ def amax_and_scale_update(
(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
) = fused_amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
fp8_meta[fp8_max_key],
fp8_meta["recipe"].margin,
fp8_meta["recipe"].amax_compute_algo,
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
update_weight_scale_inv,
)
else:
fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax(
......@@ -432,6 +463,12 @@ def amax_and_scale_update(
fp8_meta[fp8_max_key],
fp8_meta["recipe"],
)
fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse(
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
update_weight_scale_inv,
)
def get_fp8_te_dtype(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment