Unverified Commit f858dc35 authored by Jan Bielak's avatar Jan Bielak Committed by GitHub
Browse files

Rename `do_not_clear` to `_do_not_clear` (#1977)


Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent cb5013bd
...@@ -431,7 +431,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -431,7 +431,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor = self.fp8_tensor_object_map.pop(tensor_tag) tensor = self.fp8_tensor_object_map.pop(tensor_tag)
if self.double_buffering: if self.double_buffering:
tensor.do_not_clear = True tensor._do_not_clear = True
self.tensor_tag_to_buf.pop(tensor_tag, None) self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward() # the tensor should have been copied back in on_group_commit_backward()
......
...@@ -101,7 +101,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -101,7 +101,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Mark input tensors as not deletable in backward # Mark input tensors as not deletable in backward
for tensor in (input_,) + params_and_extra_inputs: for tensor in (input_,) + params_and_extra_inputs:
tensor.do_not_clear = True tensor._do_not_clear = True
# Unflatten list of parameters and extra tensor inputs # Unflatten list of parameters and extra tensor inputs
extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :] extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
......
...@@ -41,10 +41,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: ...@@ -41,10 +41,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
for t in tensors: for t in tensors:
if t is not None: if t is not None:
# Workaround for double buffering in cpu offload # Workaround for double buffering in cpu offload
if hasattr(t, "do_not_clear"): if hasattr(t, "_do_not_clear"):
continue continue
if hasattr(t, "get_data_tensors"): if hasattr(t, "get_data_tensors"):
if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()): if any(hasattr(tensor, "_do_not_clear") for tensor in t.get_data_tensors()):
continue continue
if hasattr(t, "clear"): if hasattr(t, "clear"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment