Unverified commit c141f8db authored by anj-s, committed by GitHub

[offload] Add pragma directives to ensure we ignore the backward pass functions. (#675)

* add pragma

* add mypy ignore comments

* fix comment

* add more no cover comments

* add comments
parent 8a42a8e3
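
Background on why the pragma is needed (not part of the patch itself): these `backward` methods are reached only when PyTorch's C++ autograd engine calls back into Python during `tensor.backward()`, so Python line-coverage tooling cannot attribute those lines to any test. The snippet below is a minimal, hypothetical `autograd.Function` (the `Scale` class is illustrative only, not from fairscale) showing the same pattern the patch annotates.

    import torch

    class Scale(torch.autograd.Function):
        # Hypothetical example: a custom autograd.Function whose backward
        # is reached only through the C++ autograd engine.
        @staticmethod
        def forward(ctx, x, factor):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):  # pragma: no cover
            # No Python code calls this directly; loss.backward() dispatches
            # here from C++, which is why coverage cannot measure it reliably.
            return grad_output * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    loss = Scale.apply(x, 2.0).sum()
    loss.backward()   # the C++ engine invokes Scale.backward here
    print(x.grad)     # tensor([2., 2., 2.])

Marking `backward` with `# pragma: no cover`, as this commit does, keeps such lines from counting against the reported coverage.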
@@ -117,7 +117,10 @@ class ModelShard(nn.Module):
         # Restore all the parameter buffers
         self.model_shard.to(device=self.device, non_blocking=non_blocking)

-    def backward_load(self, non_blocking: bool = True) -> None:
+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
+    def backward_load(self, non_blocking: bool = True) -> None:  # pragma: no cover
         with torch.cuda.stream(self._cpu_to_gpu_stream):
             self.model_shard.to(self.device, non_blocking=non_blocking)
@@ -125,7 +128,10 @@ class ModelShard(nn.Module):
         with torch.cuda.stream(self._gpu_to_cpu_stream):
             self.model_shard.to(self.offload_device, non_blocking=non_blocking)

-    def backward_drop(self, non_blocking: bool = True) -> None:
+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
+    def backward_drop(self, non_blocking: bool = True) -> None:  # pragma: no cover
         with torch.cuda.stream(self._gpu_to_cpu_stream):
             self.model_shard.to(self.offload_device, non_blocking=non_blocking)
@@ -206,9 +212,12 @@ class OffloadFunction(torch.autograd.Function):
             r.requires_grad = True
         return result[0] if len(result) == 1 else result

+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
     @staticmethod
     @_conditional_amp_bwd_decorator
-    def backward(ctx, *grad_outputs):  # type: ignore
+    def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
         inputs = ctx.inputs
@@ -324,9 +333,12 @@ class ShardSyncLayer(torch.autograd.Function):
         return inputs if isinstance(inputs, tuple) else (inputs,)

+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
     @staticmethod
     @_conditional_amp_bwd_decorator
-    def backward(ctx, *grad_outputs):  # type: ignore
+    def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
         load_index = ctx.index
         drop_index = load_index + 1
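
On the tooling side, `# pragma: no cover` is the exclusion marker that coverage.py recognizes by default (the pattern list is configurable via its `exclude_lines` setting). A rough sketch of exercising code under coverage programmatically, with hypothetical module and function names:

    import coverage

    cov = coverage.Coverage()      # also reads .coveragerc / setup.cfg if present
    cov.start()
    import my_module               # hypothetical module containing the pragma
    my_module.run()                # exercise the code under test
    cov.stop()
    cov.report(show_missing=True)  # lines tagged "# pragma: no cover" are excluded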