Commit 686e93cd authored by yuguo's avatar yuguo
Browse files

Merge branch 'develop_v2.5_swap' into 'develop_v2.5'

add swap env

See merge request dcutoolkit/deeplearing/TransformerEngine!40
parents c4bb6049 d19a5a44
@@ -427,7 +427,7 @@ class _LayerNormLinear(torch.autograd.Function):
         )
         nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
-        if cpu_offloading:
+        if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
             if ctx.grad_added_to_main_grad:
@@ -556,7 +556,7 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading:
+        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             if ctx.grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
                 if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
...
@@ -368,7 +368,7 @@ class _Linear(torch.autograd.Function):
         )
         nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
-        if cpu_offloading:
+        if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
             if ctx.grad_added_to_main_grad:
@@ -459,7 +459,7 @@ class _Linear(torch.autograd.Function):
             else None
         )
-        if ctx.cpu_offloading:
+        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             if ctx.grad_added_to_main_grad:
                 weight = ctx.weight_object
                 if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment