Unverified Commit 91d52ac7 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Fix pipeline parallel execution by using cloned scale inverse tensors (#659)



Use cloned scale_inv for fp8 cast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d9eb1991
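
Note on the change: the backward passes now call `torch.ops.tex_ts.cast_from_fp8_ts` with `fwd_scale_inverses`, a clone of the scale inverses taken at forward time, instead of reading the live `ctx.fp8_meta["scaling_fwd"]` buffer. The likely mechanism, inferred from the commit title: under pipeline-parallel schedules, later microbatches' forward passes update the shared scale-inverse buffer in place before this microbatch's backward runs, so dequantizing in backward with the live buffer uses the wrong scales. The toy below is a minimal sketch of that failure mode and of the clone-based fix; `ScaledIdentity`, `scale_inv`, and `use_clone` are illustrative names, not the Transformer Engine API.

```python
import torch

# Shared, mutable scale-inverse buffer, standing in for the
# fp8_meta["scaling_fwd"] scale_inv tensor in the real code.
scale_inv = torch.tensor([2.0])

class ScaledIdentity(torch.autograd.Function):
    """Toy op whose backward 'dequantizes' with a saved scale inverse."""

    @staticmethod
    def forward(ctx, x, use_clone):
        # use_clone=True mimics the fix: snapshot scale_inv at forward time.
        # use_clone=False mimics the bug: keep a reference to the live buffer.
        # Stored on ctx directly (not save_for_backward) so the in-place
        # mutation below does not trip autograd's version counter.
        ctx.saved_scale_inv = scale_inv.clone() if use_clone else scale_inv
        return x * scale_inv

    @staticmethod
    def backward(ctx, grad_out):
        # Backward scales the gradient with whatever scale_inv it saved.
        return grad_out * ctx.saved_scale_inv, None

for use_clone in (False, True):
    x = torch.ones(1, requires_grad=True)
    y = ScaledIdentity.apply(x, use_clone)
    scale_inv.fill_(100.0)  # a later microbatch's forward overwrites the buffer
    y.backward(torch.ones_like(y))
    print(f"use_clone={use_clone}: grad={x.grad.item()}")
    scale_inv.fill_(2.0)    # reset the buffer for the next iteration
```

Running the sketch prints `grad=100.0` for the live-reference case (the gradient picks up the later overwrite) and `grad=2.0` for the cloned snapshot, matching the scale that was actually used in forward.
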
@@ -484,9 +484,9 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             clear_tensor_data(ln_out_total_t, grad_output_t)
         else:
-            ln_out_total_c = tex.cast_from_fp8(
+            ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts(
                 ln_out_total,
-                ctx.fp8_meta["scaling_fwd"],
+                fwd_scale_inverses,
                 tex.FP8FwdTensors.GEMM1_INPUT,
                 fp8_dtype_forward,
                 TE_DType[ctx.activation_dtype],
@@ -676,9 +676,9 @@ class _LayerNormMLP(torch.autograd.Function):
             clear_tensor_data(fc1_out)
         else:
             if fc2_weight.requires_grad:
-                gelu_out_c = tex.cast_from_fp8(
+                gelu_out_c = torch.ops.tex_ts.cast_from_fp8_ts(
                     gelu_out,
-                    ctx.fp8_meta["scaling_fwd"],
+                    fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM2_INPUT,
                     fp8_dtype_forward,
                     TE_DType[ctx.activation_dtype],
@@ -875,9 +875,9 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             clear_tensor_data(ln_out_total_t, dgelu_t)
         else:
-            ln_out_total_c = tex.cast_from_fp8(
+            ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts(
                 ln_out_total,
-                ctx.fp8_meta["scaling_fwd"],
+                fwd_scale_inverses,
                 tex.FP8FwdTensors.GEMM1_INPUT,
                 fp8_dtype_forward,
                 TE_DType[ctx.activation_dtype],