Unverified commit 69365f88 authored by guyueh1, committed by GitHub

Bugfixes for LayerNormMLP (#1625)



* Fix GEMM+RS overlap for LayerNormMLP
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* Fix error where LayerNormMLP param.grad is None
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update dtype for wgrad GEMM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 77d64552
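
A note on the "param.grad is None" fix before the diff: when the weight gradient is fused directly into weight.main_grad and Megatron-core's custom DDP marks grad_added_to_main_grad, backward must still return a placeholder tensor instead of None, otherwise autograd never populates param.grad and downstream code that reads it fails. The diff below applies exactly this change for both FC1 and FC2 (returning a torch.empty dummy wgrad rather than None). The following is a minimal, self-contained sketch of that pattern in plain PyTorch; the FusedWgradLinear class, shapes, and CPU tensors are illustrative assumptions, not Transformer Engine code.

import torch

class FusedWgradLinear(torch.autograd.Function):
    """Toy linear op that accumulates its weight gradient into weight.main_grad."""

    @staticmethod
    def forward(ctx, inp, weight):
        # Save the original Parameter so backward can reach .main_grad and the DDP flags
        # (the real fix saves both the compute weight and the original Parameter).
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        dgrad = grad_out @ weight
        if hasattr(weight, "grad_added_to_main_grad"):
            # Fused accumulation: write the wgrad straight into the persistent buffer.
            weight.main_grad += (grad_out.t() @ inp).to(weight.main_grad.dtype)
            weight.grad_added_to_main_grad = True
            # Return a dummy tensor rather than None so autograd still sets weight.grad.
            wgrad = torch.empty(
                weight.main_grad.shape, dtype=weight.dtype, requires_grad=False
            )
        else:
            wgrad = grad_out.t() @ inp
        return dgrad, wgrad

# Hypothetical usage: without the dummy wgrad above, w.grad would remain None.
w = torch.nn.Parameter(torch.randn(4, 3))
w.main_grad = torch.zeros(4, 3, dtype=torch.float32)
w.grad_added_to_main_grad = False
x = torch.randn(2, 3, requires_grad=True)
FusedWgradLinear.apply(x, w).sum().backward()
assert w.grad is not None
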
@@ -390,7 +390,6 @@ class _LayerNormMLP(torch.autograd.Function):
             dim_size[0] = dim_size[0] // tp_world_size
             dim_size[1] = fc2_weight.size(0)
             rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
-            fc2_out = ub_obj_fc2out.get_buffer(output_quantizer)
         else:
             dim_size = list(act_out.size())
             dim_size[1] = fc2_weight.size(0)
@@ -467,11 +466,13 @@ class _LayerNormMLP(torch.autograd.Function):
                 ln_weight,
                 ln_out,
                 fc1_weight_final,
+                fc1_weight,
                 fc1_bias,
                 fc1_out,
                 fc1_out_without_bias,
                 act_out,
                 fc2_weight_final,
+                fc2_weight,
                 fc2_bias,
                 mu,
                 rsigma,
@@ -584,11 +585,13 @@ class _LayerNormMLP(torch.autograd.Function):
             ln_weight,
             ln_out,
             fc1_weight,
+            origin_fc1_weight,
             fc1_bias,
             fc1_out,
             fc1_out_without_bias,
             act_out,
             fc2_weight,
+            origin_fc2_weight,
             fc2_bias,
             mu,
             rsigma,
@@ -607,7 +610,7 @@ class _LayerNormMLP(torch.autograd.Function):
         )
         fc2_weight_main_grad = (
             ctx.fc2_main_grad
-            if fc2_weight is not None
+            if origin_fc2_weight is not None
             and ctx.fuse_wgrad_accumulation
             and ctx.fc2_weight_requires_grad
             else None
@@ -616,8 +619,8 @@ class _LayerNormMLP(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
         if ctx.fuse_wgrad_accumulation:
-            fc1_weight.main_grad = fc1_weight_main_grad
-            fc2_weight.main_grad = fc2_weight_main_grad
+            origin_fc1_weight.main_grad = fc1_weight_main_grad
+            origin_fc2_weight.main_grad = fc2_weight_main_grad

         # TODO: Fix this # pylint: disable=fixme
         # Gather saved autograd context tensors when running with FSDP
@@ -735,7 +738,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 grad_output,
                 get_workspace(),
                 out_dtype=(
-                    fc2_weight.main_grad.dtype
+                    origin_fc2_weight.main_grad.dtype
                     if ctx.fuse_wgrad_accumulation
                     else ctx.activation_dtype
                 ),
@@ -745,7 +748,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc2_bias if fc2_bias_grad is None else None,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 use_split_accumulator=_2X_ACC_WGRAD,
-                out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
             if fc2_bias_grad is None:
                 fc2_bias_grad = fc2_bias_grad_
@@ -899,7 +902,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 dact,
                 get_workspace(),
                 out_dtype=(
-                    fc1_weight.main_grad.dtype
+                    origin_fc1_weight.main_grad.dtype
                     if ctx.fuse_wgrad_accumulation
                     else ctx.activation_dtype
                 ),
@@ -907,7 +910,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 grad=fuse_gemm_and_bias_fc1_wgrad,
                 bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 ub=ub_obj_fc1_wgrad,
                 ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
                 extra_output=fc1_dgrad_rs_out,
@@ -968,16 +971,21 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.fc1_weight_requires_grad:
             # Handle custom DDP from mcore.
             if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"):
-                fc1_weight.grad_added_to_main_grad = True
-                if getattr(fc1_weight, "zero_out_wgrad", False):
+                origin_fc1_weight.grad_added_to_main_grad = True
+                if getattr(origin_fc1_weight, "zero_out_wgrad", False):
                     fc1_wgrad = torch.zeros(
-                        fc1_weight.main_grad.shape,
-                        dtype=fc1_weight.dtype,
+                        origin_fc1_weight.main_grad.shape,
+                        dtype=origin_fc1_weight.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False,
                     )
                 else:
-                    fc1_wgrad = None
+                    fc1_wgrad = torch.empty(
+                        origin_fc1_weight.main_grad.shape,
+                        dtype=origin_fc1_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 fc1_wgrad = None
             else:
@@ -985,17 +993,24 @@ class _LayerNormMLP(torch.autograd.Function):

         if ctx.fc2_weight_requires_grad:
             # Handle custom DDP from mcore.
-            if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"):
-                fc2_weight.grad_added_to_main_grad = True
-                if getattr(fc2_weight, "zero_out_wgrad", False):
+            if ctx.fuse_wgrad_accumulation and hasattr(
+                origin_fc2_weight, "grad_added_to_main_grad"
+            ):
+                origin_fc2_weight.grad_added_to_main_grad = True
+                if getattr(origin_fc2_weight, "zero_out_wgrad", False):
                     fc2_wgrad = torch.zeros(
-                        fc2_weight.main_grad.shape,
-                        dtype=fc2_weight.dtype,
+                        origin_fc2_weight.main_grad.shape,
+                        dtype=origin_fc2_weight.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False,
                     )
                 else:
-                    fc2_wgrad = None
+                    fc2_wgrad = torch.empty(
+                        origin_fc2_weight.main_grad.shape,
+                        dtype=origin_fc2_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 fc2_wgrad = None
             else:
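
The "Update dtype for wgrad GEMM" change, visible in the out_dtype hunks above, encodes a simple rule: when the weight gradient is accumulated into weight.main_grad, the wgrad GEMM should produce its output in the dtype of that persistent buffer (typically FP32), not in the activation dtype (typically BF16/FP16). Below is a small, hypothetical helper illustrating the rule; the function name and shapes are invented for this sketch and are not a Transformer Engine API.

import torch

def wgrad_out_dtype(
    weight: torch.nn.Parameter,
    activation_dtype: torch.dtype,
    fuse_wgrad_accumulation: bool,
) -> torch.dtype:
    """Pick the output dtype for a weight-gradient GEMM."""
    if fuse_wgrad_accumulation:
        # Accumulate in the precision of the persistent gradient buffer (main_grad).
        return weight.main_grad.dtype
    # Otherwise the wgrad is produced in the compute/activation precision.
    return activation_dtype

# Hypothetical usage:
w = torch.nn.Parameter(torch.zeros(8, 8, dtype=torch.bfloat16))
w.main_grad = torch.zeros(8, 8, dtype=torch.float32)
assert wgrad_out_dtype(w, torch.bfloat16, True) == torch.float32
assert wgrad_out_dtype(w, torch.bfloat16, False) == torch.bfloat16
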