"vscode:/vscode.git/clone" did not exist on "825d8892b54af80516ce98a89b595018f756a6d3"
Unverified commit aaba1b01, authored by Kirthi Shankar Sivamani, committed by GitHub

wgrad-accumulation patch for legacy megatron-LM (#451)

* wgrad-accumulation patch for legacy megatron-LM
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix fused attention tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a29e2e11
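
In brief: Transformer Engine can fuse weight-gradient accumulation into the wgrad GEMM, writing directly into a main_grad buffer owned by the data-parallel wrapper. Megatron-core's custom DDP marks its parameters with a grad_added_to_main_grad attribute; legacy Megatron-LM's DDP never defines it. The patch therefore probes for the attribute instead of assigning it unconditionally. A minimal sketch of the patched decision logic (resolve_wgrad is a stand-in name, not library code):

# Hypothetical helper mirroring the control flow in the diffs below.
def resolve_wgrad(weight, wgrad, fuse_wgrad_accumulation):
    """Decide what backward() should hand to autograd as the weight grad."""
    if not weight.requires_grad:
        return None
    if fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
        # mcore-style DDP: flag that the fused GEMM already wrote main_grad.
        weight.grad_added_to_main_grad = True
        return wgrad
    if fuse_wgrad_accumulation:
        # Legacy Megatron-LM: the gradient already lives in weight.main_grad,
        # so return None rather than a duplicate for autograd to apply.
        return None
    return wgrad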
@@ -648,7 +648,7 @@ class _dpa_fp8(torch.autograd.Function):
         ZInv = None
         philox_unpacked = None

-        qkv_out = ext.fp8_gemm(
+        qkv_out, _ = ext.fp8_gemm(
             qkv_weight_fp8,
             fp8_meta["scaling_fwd"].scale_inv,
             tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -798,7 +798,7 @@ class _dpa_fp8(torch.autograd.Function):
         )

         # QKV DGRAD
-        qkv_dgrad = ext.fp8_gemm(
+        qkv_dgrad, _ = ext.fp8_gemm(
             qkv_weight_t_fp8,
             fwd_scale_inverses,
             tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -812,7 +812,7 @@ class _dpa_fp8(torch.autograd.Function):
             use_split_accumulator=_2X_ACC_DGRAD,
         )
         # QKV WGRAD
-        qkv_wgrad = ext.fp8_gemm(
+        qkv_wgrad, _ = ext.fp8_gemm(
             inputmat_t,
             fwd_scale_inverses,
             tex.FP8FwdTensors.GEMM1_INPUT,
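
These three fused-attention call sites change only in how they receive the GEMM result: ext.fp8_gemm returns a tuple in this version, and these sites need only the first element. A self-contained illustration of the call-site pattern, with fp8_gemm_like as a hypothetical stand-in for the real TE extension:

from typing import Optional, Tuple
import torch

def fp8_gemm_like(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Stand-in for ext.fp8_gemm: returns the product plus an optional
    # auxiliary output (e.g. a saved pre-activation) that callers may ignore.
    return a @ b, None

# Old call sites bound a single value; after the signature change they
# unpack the pair and discard the auxiliary output:
qkv_out, _ = fp8_gemm_like(torch.randn(4, 8), torch.randn(8, 16))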
@@ -473,14 +473,20 @@ class _LayerNormLinear(torch.autograd.Function):
         if not ctx.use_bias:
             grad_bias = None
         if weight.requires_grad:
             # Handle custom DDP from mcore.
-            weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+            if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
+                weight.grad_added_to_main_grad = True
+            elif ctx.fuse_wgrad_accumulation:
+                wgrad = None
+        else:
+            wgrad = None

         return (
             dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
             dbeta,
-            wgrad if weight.requires_grad else None,
+            wgrad,
             None,
             None,
             grad_bias,
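
The hasattr probe is the whole detection mechanism: mcore's DDP pre-creates the attribute on every parameter it wraps, while legacy Megatron-LM never defines it. A usage sketch under that assumption (make_param is hypothetical scaffolding, not TE or Megatron code):

import torch

def make_param(mcore_style: bool) -> torch.nn.Parameter:
    p = torch.nn.Parameter(torch.zeros(4, 4))
    if mcore_style:
        # mcore's DDP allocates an fp32 main_grad buffer and the marker
        # attribute on each parameter it manages.
        p.main_grad = torch.zeros_like(p)
        p.grad_added_to_main_grad = False
    return p

fuse = True  # stands in for ctx.fuse_wgrad_accumulation
mcore_w, legacy_w = make_param(True), make_param(False)
print(fuse and hasattr(mcore_w, "grad_added_to_main_grad"))   # True: set the flag
print(fuse and hasattr(legacy_w, "grad_added_to_main_grad"))  # False: wgrad = None instead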
@@ -794,20 +794,34 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             dbeta = None
         if fc1_weight.requires_grad:
+            # Handle custom DDP from mcore.
+            if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'):
+                fc1_weight.grad_added_to_main_grad = True
+            elif ctx.fuse_wgrad_accumulation:
+                fc1_wgrad = None
+        else:
+            fc1_wgrad = None
+
         if fc2_weight.requires_grad:
             # Handle custom DDP from mcore.
-            fc1_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
-            fc2_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+            if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'):
+                fc2_weight.grad_added_to_main_grad = True
+            elif ctx.fuse_wgrad_accumulation:
+                fc2_wgrad = None
+        else:
+            fc2_wgrad = None

         return (
             dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
             dbeta,
-            fc1_wgrad if fc1_weight.requires_grad else None,
+            fc1_wgrad,
             None,
             None,
             fc1_bias_grad if ctx.use_fc1_bias else None,
             None,
-            fc2_wgrad if fc2_weight.requires_grad else None,
+            fc2_wgrad,
             None,
             None,
             fc2_bias_grad if ctx.use_fc2_bias else None,
@@ -415,11 +415,17 @@ class _Linear(torch.autograd.Function):
         if not ctx.use_bias:
             grad_bias = None
         if weight.requires_grad:
             # Handle custom DDP from mcore.
-            weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+            if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
+                weight.grad_added_to_main_grad = True
+            elif ctx.fuse_wgrad_accumulation:
+                wgrad = None
+        else:
+            wgrad = None

         return (
-            wgrad if weight.requires_grad else None,
+            wgrad,
             None,
             None,
             dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
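
For context on why returning wgrad = None is safe in the fused case: the wgrad GEMM accumulates directly into the preallocated main_grad buffer, so autograd must not apply a second copy. A plain-PyTorch sketch of the arithmetic (the real path uses TE's fused GEMM kernels; shapes are illustrative):

import torch

x = torch.randn(32, 16)         # saved forward input: [tokens, in_features]
dy = torch.randn(32, 8)         # incoming gradient:   [tokens, out_features]
main_grad = torch.zeros(8, 16)  # fp32 accumulator owned by the DDP wrapper

# Fused accumulation: main_grad += dy^T @ x, done inside the GEMM on the
# real path instead of as a separate add after materializing a fresh dW.
main_grad.addmm_(dy.t(), x)

# With the gradient already in main_grad, backward() returns None (legacy
# Megatron-LM) or sets grad_added_to_main_grad = True (mcore) so the
# gradient is not double-counted.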