Unverified commit c2937c5a, authored by Alp Dener, committed by GitHub
Browse files

[PyTorch] `te.Linear` FP8 DGRAD+RS output bugfix (#1412)



* corrected RS overlap BF16 output clashing with Float8Tensor constructor
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed empty dgrad buffer dtype at initialization
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d63cbb4
......@@ -506,13 +506,6 @@ class _Linear(torch.autograd.Function):
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
dgrad = ub_obj_wgrad.get_ubuf_output(1)
if dgrad is None:
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
dgrad_shape[0] = dgrad_shape[0] * tp_world_size
dgrad = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)
(
grad_output,
grad_output_c,
......@@ -550,6 +543,7 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
output_dtype = ctx.activation_dtype
if ctx.requires_dgrad:
if ctx.fp8:
if ctx.is_input_fp8 or (
......@@ -570,6 +564,14 @@ class _Linear(torch.autograd.Function):
None,
ctx.activation_dtype,
)
if dgrad is None:
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
dgrad_shape[0] = dgrad_shape[0] * tp_world_size
dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device)
if ctx.requires_dgrad:
if ctx.fp8:
_ = fp8_gemm(
weight_fp8.transpose_2d(),
weight_fp8._scale_inv,
......@@ -593,8 +595,7 @@ class _Linear(torch.autograd.Function):
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
if output_dtype == torch.uint8:
elif output_dtype == torch.uint8:
dgrad = Float8Tensor(
data=dgrad,
fp8_meta=ctx.fp8_meta,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment