Unverified commit 1321b9b5, authored by guyueh1, committed by GitHub

Ensure weight transpose is valid for Hopper FP8 training (#1596)



* Update usage of weightmat before saving for backward
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix for layernorm mlp
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

---------
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent e14d1472
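
Context for the diff below: in FP8 execution a quantized weight keeps two layouts, a row-wise one consumed by the forward GEMM and a column-wise (transposed) one consumed by the dgrad GEMM in backward, as stated in the comments added by this change. If the column-wise data is not valid when the weight is saved for backward, the dgrad GEMM cannot use it, so the fix calls `update_usage(columnwise_usage=True)` on quantized weights whenever the input requires a gradient. The following is a minimal, self-contained sketch of that pattern; `ToyQuantizedTensor` and `prepare_weight_for_backward` are hypothetical stand-ins, not Transformer Engine's actual classes, and only mirror the `QuantizedTensor.update_usage(...)` calls visible in the diff.

```python
# Minimal sketch (hypothetical types): a quantized weight that tracks whether
# its row-wise data and its column-wise (transposed) data are materialized.
from dataclasses import dataclass


@dataclass
class ToyQuantizedTensor:
    rowwise_valid: bool = True       # layout used by the forward GEMM
    columnwise_valid: bool = False   # transposed layout used by the dgrad GEMM

    def update_usage(self, rowwise_usage=None, columnwise_usage=None):
        # Mirror of the update_usage(...) calls in the diff: request that a
        # given layout be kept/valid (True) or dropped (False).
        if rowwise_usage is not None:
            self.rowwise_valid = rowwise_usage
        if columnwise_usage is not None:
            # In the real library this would materialize the transposed FP8
            # data if it is missing; here we only flip a flag.
            self.columnwise_valid = columnwise_usage


def prepare_weight_for_backward(weightmat, input_requires_grad: bool):
    # The pattern added by this commit: if the input needs a gradient, the
    # backward dgrad GEMM will consume the column-wise (transposed) weight,
    # so ensure that usage is valid before saving the weight for backward.
    if input_requires_grad and isinstance(weightmat, ToyQuantizedTensor):
        weightmat.update_usage(columnwise_usage=True)
    return weightmat


if __name__ == "__main__":
    w = ToyQuantizedTensor()
    prepare_weight_for_backward(w, input_requires_grad=True)
    assert w.columnwise_valid  # transpose is now guaranteed for dgrad
```
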
@@ -341,7 +341,7 @@ class _LayerNormLinear(torch.autograd.Function):
             weight.requires_grad and parallel_mode == "column" and sequence_parallel
         )
-        # Input with column-wise usage is needed for dgrad GEMM.
+        # Input with column-wise usage is needed for wgrad GEMM.
         if backward_needs_input:
             if isinstance(ln_out, QuantizedTensor):
                 # For sequence parallel in vanilla FP8, rowwise data is
@@ -350,6 +350,11 @@ class _LayerNormLinear(torch.autograd.Function):
                 if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                     ln_out.update_usage(rowwise_usage=False)
 
+        # Weight with column-wise usage is needed for dgrad GEMM.
+        if inp.requires_grad:
+            if isinstance(weightmat, QuantizedTensor):
+                weightmat.update_usage(columnwise_usage=True)
+
         if cpu_offloading:
             if fp8 and weightmat is not None:
                 set_offloading_param(weightmat, "weight_offloading", True)
@@ -442,6 +442,14 @@ class _LayerNormMLP(torch.autograd.Function):
                 ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
                 extra_output=rs_out,
             )
 
+        # Weight with column-wise usage is needed for dgrad GEMM.
+        if is_grad_enabled and inp.requires_grad:
+            if isinstance(fc1_weight_final, QuantizedTensor):
+                fc1_weight_final.update_usage(columnwise_usage=True)
+            if isinstance(fc2_weight_final, QuantizedTensor):
+                fc2_weight_final.update_usage(columnwise_usage=True)
+
         if not is_grad_enabled:
             clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
         else:
@@ -272,6 +272,11 @@ class _Linear(torch.autograd.Function):
                 inputmat.update_usage(rowwise_usage=False)
         saved_inputmat = inputmat
 
+        # Weight with column-wise usage is needed for dgrad GEMM.
+        if inp.requires_grad:
+            if isinstance(weightmat, QuantizedTensor):
+                weightmat.update_usage(columnwise_usage=True)
+
         if cpu_offloading:
             set_offloading_param(weight, "weight_offloading", True)
             set_offloading_param(weightmat, "weight_offloading", True)