Unverified commit ba5dc5dd authored by vasunvidia, committed by GitHub

Enable reuse of dummy wgrad tensor (#1651)



* Use dummy wgrads for lower memory consumption
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix to avoid sharing gradients.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Disable automatic use of batch_p2p_comm for CP2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change weight to origin_weight for LN_LINEAR
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent db2aaa9e
@@ -616,7 +616,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(cp_group)
         send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
         causal = "causal" in attn_mask_type
         padding = "padding" in attn_mask_type
@@ -1564,7 +1564,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(ctx.cp_group)
         send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
         q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
             restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
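The two hunks above implement the "Disable automatic use of batch_p2p_comm for CP2" item: batched P2P exchange is no longer switched on implicitly when `cp_size == 2`; only the `NVTE_BATCH_MHA_P2P_COMM` environment variable controls it. A minimal sketch of the resulting behaviour, assuming only the variable shown in the diff (the surrounding script is hypothetical):

```python
import os

# Opt in explicitly before the attention function reads the flag; with the
# default of "0" the context-parallel KV exchange uses unbatched send/recv
# even for a two-rank CP group, which previously enabled batching automatically.
os.environ.setdefault("NVTE_BATCH_MHA_P2P_COMM", "1")  # hypothetical opt-in

batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
print("batched P2P communication:", bool(batch_p2p_comm))
```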
@@ -43,6 +43,7 @@ _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True
 _multi_stream_cublas_workspace = []
+_dummy_wgrads = {}
 _cublas_workspace = None
 _ub_communicators = None
 _NUM_MAX_UB_STREAMS = 3
@@ -78,6 +79,22 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
     return _multi_stream_cublas_workspace
 
 
+def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
+    """Returns a dummy tensor of given shape."""
+    assert len(shape) == 2
+    global _dummy_wgrads
+    if (shape[0], shape[1], dtype) not in _dummy_wgrads:
+        _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(
+            shape,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=False,
+        )
+    if zero:
+        _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0)
+    return _dummy_wgrads[(shape[0], shape[1], dtype)].detach()
+
+
 def initialize_ub(
     shape: list,
     tp_size: int,
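The new `get_dummy_wgrad` helper keeps one placeholder buffer per `(rows, cols, dtype)` key in the module-level `_dummy_wgrads` cache and hands out detached views of it, so each backward pass no longer allocates a fresh tensor of full weight-gradient size. A hedged usage sketch; the import path is inferred from the hunk headers and is an assumption:

```python
import torch

# Assumed location of the helper added in the hunk above.
from transformer_engine.pytorch.module.base import get_dummy_wgrad

a = get_dummy_wgrad([4096, 1024], torch.bfloat16)
b = get_dummy_wgrad([4096, 1024], torch.bfloat16, zero=True)  # zero-fills the cached buffer

# Both calls return views of the same cached storage: roughly 8 MiB is
# allocated once instead of on every backward pass.
assert a.data_ptr() == b.data_ptr()
assert not a.requires_grad and not b.requires_grad
```

Because callers with the same shape and dtype share this buffer, its contents are only defined immediately after a `zero=True` call; the fused-accumulation paths below either ignore the value entirely (the pre-change code used `torch.empty` there) or request `zero=True` when a zero-filled gradient is expected downstream.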
@@ -19,6 +19,7 @@ from .base import (
     get_workspace,
     get_ub,
     TransformerEngineBaseModule,
+    get_dummy_wgrad,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
@@ -796,18 +797,15 @@ class _LayerNormLinear(torch.autograd.Function):
             if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
                 origin_weight.grad_added_to_main_grad = True
                 if getattr(origin_weight, "zero_out_wgrad", False):
-                    wgrad = torch.zeros(
-                        origin_weight.main_grad.shape,
-                        dtype=origin_weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(origin_weight.main_grad.shape),
+                        origin_weight.dtype,
+                        zero=True,
                     )
                 else:
-                    wgrad = torch.empty(
-                        origin_weight.main_grad.shape,
-                        dtype=origin_weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(origin_weight.main_grad.shape),
+                        origin_weight.dtype,
                     )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None
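In `_LayerNormLinear.backward` the placeholder only fills the weight-gradient slot that autograd expects to receive: with `fuse_wgrad_accumulation` the real gradient has already been accumulated into `origin_weight.main_grad`, which is why the pre-change code could return an uninitialized `torch.empty` tensor there. A rough sketch of that contract, assuming a Megatron-style `main_grad` buffer; all names outside the diff are illustrative:

```python
import torch

from transformer_engine.pytorch.module.base import get_dummy_wgrad  # assumed path, as above

# Illustrative stand-ins; the real logic lives inside _LayerNormLinear.backward.
weight = torch.nn.Parameter(torch.empty(1024, 1024, device="cuda", dtype=torch.bfloat16))
weight.main_grad = torch.zeros(1024, 1024, device="cuda", dtype=torch.float32)  # fp32 master grad

def wgrad_step(grad_output: torch.Tensor, ln_out: torch.Tensor) -> torch.Tensor:
    # The real weight gradient goes straight into main_grad...
    weight.main_grad.add_(grad_output.t().float() @ ln_out.float())
    # ...so the tensor handed back to autograd only needs the right shape and
    # dtype; its values are never consumed by the optimizer.
    return get_dummy_wgrad(list(weight.main_grad.shape), weight.dtype)
```

The identical change in `_Linear.backward` below swaps the same `torch.zeros`/`torch.empty` allocations for the shared cached buffer.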
@@ -16,6 +16,7 @@ from .base import (
     get_workspace,
     get_ub,
     TransformerEngineBaseModule,
+    get_dummy_wgrad,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
@@ -688,18 +689,15 @@ class _Linear(torch.autograd.Function):
             ):
                 weight.grad_added_to_main_grad = True
                 if getattr(weight, "zero_out_wgrad", False):
-                    wgrad = torch.zeros(
-                        weight.main_grad.shape,
-                        dtype=weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(weight.main_grad.shape),
+                        weight.dtype,
+                        zero=True,
                     )
                 else:
-                    wgrad = torch.empty(
-                        weight.main_grad.shape,
-                        dtype=weight.dtype,
-                        device=torch.cuda.current_device(),
-                        requires_grad=False,
+                    wgrad = get_dummy_wgrad(
+                        list(weight.main_grad.shape),
+                        weight.dtype,
                     )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None