Commit d19a5a44 authored by evt_fugx1's avatar evt_fugx1
Browse files

add swap env

parent c4bb6049
......@@ -427,7 +427,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
if cpu_offloading:
if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
......@@ -556,7 +556,7 @@ class _LayerNormLinear(torch.autograd.Function):
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one.
if ctx.cpu_offloading:
if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
if ctx.grad_added_to_main_grad:
origin_weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
......
......@@ -368,7 +368,7 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
if cpu_offloading:
if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
......@@ -459,7 +459,7 @@ class _Linear(torch.autograd.Function):
else None
)
if ctx.cpu_offloading:
if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
if ctx.grad_added_to_main_grad:
weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment