Unverified Commit 69365f88 authored by guyueh1, committed by GitHub
Browse files

Bugfixes for LayerNormMLP (#1625)



* Fix GEMM+RS overlap for LayerNormMLP
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* Fix error LayerNormMLP param.grad is None
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update dtype for wgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 77d64552
......@@ -390,7 +390,6 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
fc2_out = ub_obj_fc2out.get_buffer(output_quantizer)
else:
dim_size = list(act_out.size())
dim_size[1] = fc2_weight.size(0)
......@@ -467,11 +466,13 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight,
ln_out,
fc1_weight_final,
fc1_weight,
fc1_bias,
fc1_out,
fc1_out_without_bias,
act_out,
fc2_weight_final,
fc2_weight,
fc2_bias,
mu,
rsigma,
......@@ -584,11 +585,13 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight,
ln_out,
fc1_weight,
origin_fc1_weight,
fc1_bias,
fc1_out,
fc1_out_without_bias,
act_out,
fc2_weight,
origin_fc2_weight,
fc2_bias,
mu,
rsigma,
......@@ -607,7 +610,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
fc2_weight_main_grad = (
ctx.fc2_main_grad
if fc2_weight is not None
if origin_fc2_weight is not None
and ctx.fuse_wgrad_accumulation
and ctx.fc2_weight_requires_grad
else None
......@@ -616,8 +619,8 @@ class _LayerNormMLP(torch.autograd.Function):
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one.
if ctx.fuse_wgrad_accumulation:
fc1_weight.main_grad = fc1_weight_main_grad
fc2_weight.main_grad = fc2_weight_main_grad
origin_fc1_weight.main_grad = fc1_weight_main_grad
origin_fc2_weight.main_grad = fc2_weight_main_grad
# TODO: Fix this # pylint: disable=fixme
# Gather saved autograd context tensors when running with FSDP
......@@ -735,7 +738,7 @@ class _LayerNormMLP(torch.autograd.Function):
grad_output,
get_workspace(),
out_dtype=(
fc2_weight.main_grad.dtype
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
......@@ -745,7 +748,7 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc2_bias if fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if fc2_bias_grad is None:
fc2_bias_grad = fc2_bias_grad_
......@@ -899,7 +902,7 @@ class _LayerNormMLP(torch.autograd.Function):
dact,
get_workspace(),
out_dtype=(
fc1_weight.main_grad.dtype
origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
......@@ -907,7 +910,7 @@ class _LayerNormMLP(torch.autograd.Function):
grad=fuse_gemm_and_bias_fc1_wgrad,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
accumulate=accumulate_wgrad_into_param_main_grad,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub=ub_obj_fc1_wgrad,
ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
extra_output=fc1_dgrad_rs_out,
......@@ -968,16 +971,21 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fc1_weight_requires_grad:
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"):
fc1_weight.grad_added_to_main_grad = True
if getattr(fc1_weight, "zero_out_wgrad", False):
origin_fc1_weight.grad_added_to_main_grad = True
if getattr(origin_fc1_weight, "zero_out_wgrad", False):
fc1_wgrad = torch.zeros(
fc1_weight.main_grad.shape,
dtype=fc1_weight.dtype,
origin_fc1_weight.main_grad.shape,
dtype=origin_fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
fc1_wgrad = None
fc1_wgrad = torch.empty(
origin_fc1_weight.main_grad.shape,
dtype=origin_fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
fc1_wgrad = None
else:
......@@ -985,17 +993,24 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fc2_weight_requires_grad:
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"):
fc2_weight.grad_added_to_main_grad = True
if getattr(fc2_weight, "zero_out_wgrad", False):
if ctx.fuse_wgrad_accumulation and hasattr(
origin_fc2_weight, "grad_added_to_main_grad"
):
origin_fc2_weight.grad_added_to_main_grad = True
if getattr(origin_fc2_weight, "zero_out_wgrad", False):
fc2_wgrad = torch.zeros(
fc2_weight.main_grad.shape,
dtype=fc2_weight.dtype,
origin_fc2_weight.main_grad.shape,
dtype=origin_fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
fc2_wgrad = None
fc2_wgrad = torch.empty(
origin_fc2_weight.main_grad.shape,
dtype=origin_fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
fc2_wgrad = None
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment