Unverified commit c141f8db authored by anj-s, committed by GitHub

[offload] Add pragma directives to ensure we ignore the backward pass functions. (#675)

* add pragma

* add mypy ignore comments

* fix comment

* add more no cover comments

* add comments
parent 8a42a8e3
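For context on the change: "# pragma: no cover" is the exclusion marker that coverage.py recognizes by default, and backward() on a torch.autograd.Function subclass is invoked by the C++ autograd engine during loss.backward() rather than called directly from Python test code, so its line coverage is not reliably recorded. A minimal, self-contained sketch of the annotated pattern (the Scale function below is illustrative, not fairscale code):

# Minimal sketch: a custom autograd.Function whose backward() is only entered
# from the C++ autograd engine via .backward(), which is why such overrides
# are marked "# pragma: no cover" to keep them out of coverage reports.
import torch


class Scale(torch.autograd.Function):  # illustrative example, not fairscale code
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):  # type: ignore # pragma: no cover
        # Called by the autograd engine, not directly from Python tests.
        return grad_output * ctx.factor, None


x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(torch.allclose(x.grad, torch.full_like(x, 2.0)))  # True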
@@ -117,7 +117,10 @@ class ModelShard(nn.Module):
             # Restore all the parameter buffers
             self.model_shard.to(device=self.device, non_blocking=non_blocking)
 
-    def backward_load(self, non_blocking: bool = True) -> None:
+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
+    def backward_load(self, non_blocking: bool = True) -> None:  # pragma: no cover
         with torch.cuda.stream(self._cpu_to_gpu_stream):
             self.model_shard.to(self.device, non_blocking=non_blocking)
@@ -125,7 +128,10 @@ class ModelShard(nn.Module):
         with torch.cuda.stream(self._gpu_to_cpu_stream):
             self.model_shard.to(self.offload_device, non_blocking=non_blocking)
 
-    def backward_drop(self, non_blocking: bool = True) -> None:
+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
+    def backward_drop(self, non_blocking: bool = True) -> None:  # pragma: no cover
         with torch.cuda.stream(self._gpu_to_cpu_stream):
             self.model_shard.to(self.offload_device, non_blocking=non_blocking)
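Both backward_load and backward_drop enqueue the device move on a dedicated CUDA stream so the copy can overlap with compute on the default stream. A self-contained sketch of that pattern (the helper name and setup are illustrative, not fairscale API):

# Sketch of the side-stream device move used by backward_load/backward_drop.
import torch
import torch.nn as nn


def move_shard_on_stream(shard: nn.Module, device: torch.device,
                         stream: torch.cuda.Stream, non_blocking: bool = True) -> None:
    # Work enqueued inside this context goes to `stream` instead of the
    # default stream, so the copy can overlap with ongoing compute.
    with torch.cuda.stream(stream):
        shard.to(device, non_blocking=non_blocking)


if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    shard = nn.Linear(8, 8)  # stand-in for a model shard
    move_shard_on_stream(shard, torch.device("cuda"), copy_stream)
    # Make the default stream wait for the copy before using the shard.
    torch.cuda.current_stream().wait_stream(copy_stream)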
@@ -206,9 +212,12 @@ class OffloadFunction(torch.autograd.Function):
                 r.requires_grad = True
         return result[0] if len(result) == 1 else result
 
+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
     @staticmethod
     @_conditional_amp_bwd_decorator
-    def backward(ctx, *grad_outputs):  # type: ignore
+    def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
         inputs = ctx.inputs
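The context lines above also show the guard in OffloadFunction.backward: gradients must be produced via .backward(), since torch.autograd._is_checkpoint_valid() is false inside torch.autograd.grad(). A hedged sketch of the same restriction as it surfaces with reentrant torch.utils.checkpoint, which this backward resembles (the exact error message and the use_reentrant flag depend on the PyTorch version):

# Hedged sketch of the guard shown above: with reentrant checkpointing,
# gradients must come from .backward(); torch.autograd.grad() fails the
# torch.autograd._is_checkpoint_valid() check.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

out = checkpoint(layer, x, use_reentrant=True).sum()
out.backward()  # supported path: gradients driven by .backward()

out2 = checkpoint(layer, x, use_reentrant=True).sum()
try:
    torch.autograd.grad(out2, x)  # raises: checkpointing is not compatible with .grad()
except RuntimeError as err:
    print(f"RuntimeError: {err}")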
@@ -324,9 +333,12 @@ class ShardSyncLayer(torch.autograd.Function):
         return inputs if isinstance(inputs, tuple) else (inputs,)
 
+    # Ignore the following function for code coverage since the backward pass
+    # is triggered by C++ code and cannot be calculated when overriding
+    # autograd.Function
     @staticmethod
     @_conditional_amp_bwd_decorator
-    def backward(ctx, *grad_outputs):  # type: ignore
+    def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
         load_index = ctx.index
         drop_index = load_index + 1
...
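The ShardSyncLayer.backward context lines hint at the bookkeeping: walking the shards in reverse, the shard at load_index is brought back to the GPU while the shard at drop_index = load_index + 1 (the one whose gradients were just computed) is returned to the offload device. A purely illustrative sketch of that index arithmetic, with hypothetical names:

# Illustrative sketch of the backward-pass load/drop indexing hinted at above;
# names and structure are hypothetical, not fairscale's implementation.
from typing import List


def backward_shard_schedule(index: int, num_shards: int) -> List[str]:
    load_index = index            # shard needed next, moving backwards
    drop_index = load_index + 1   # shard whose gradients were just computed
    ops = []
    if load_index >= 0:
        ops.append(f"backward_load(shard {load_index})")
    if drop_index < num_shards:
        ops.append(f"backward_drop(shard {drop_index})")
    return ops


print(backward_shard_schedule(index=2, num_shards=4))
# ['backward_load(shard 2)', 'backward_drop(shard 3)']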