Unverified Commit 34384e1b authored by anj-s, committed by GitHub

[offload] Audit OffloadModel API, add error messages and remove redundant code path. (#557)

* renaming/adding error messages

* address comments

* address comments

* add more comments

* add more comments
parent a0458b98
@@ -49,7 +49,7 @@ def get_model_and_optimizer(args, device, benchmark_config, model_specs):
     optimizer = torch.optim.SGD
     model = OffloadModel(
-        model_cpu=model,
+        model=model,
         device=torch.device("cuda"),
         offload_device=torch.device("cpu"),
         num_slices=benchmark_config["slices"],
...
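For reference, a minimal sketch of constructing `OffloadModel` with the renamed `model` keyword introduced in this patch. The import path is an assumption based on this branch's layout (`fascale`'s experimental offload module); the layer sizes are placeholders.

```python
# Sketch only: constructing OffloadModel with the renamed `model` argument.
import torch
from torch import nn
from fairscale.experimental.nn.offload import OffloadModel  # assumed import path

seq_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
offload_model = OffloadModel(
    model=seq_model,                     # was `model_cpu=` before this change
    device=torch.device("cuda"),         # compute device; CUDA is now required by the new checks
    offload_device=torch.device("cpu"),  # where idle shards are parked
    num_slices=2,                        # split the Sequential into 2 shards
)
```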
@@ -13,7 +13,7 @@ import torch
from torch import nn
-def conditional_amp_fwd_decorator(orig_func):  # type: ignore
+def _conditional_amp_fwd_decorator(orig_func):  # type: ignore
    if hasattr(torch.cuda.amp, "custom_fwd"):
        return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
@@ -25,7 +25,7 @@ def conditional_amp_fwd_decorator(orig_func):  # type: ignore
    return inner_decorator
-def conditional_amp_bwd_decorator(orig_func):  # type: ignore
+def _conditional_amp_bwd_decorator(orig_func):  # type: ignore
    if hasattr(torch.cuda.amp, "custom_bwd"):
        return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore
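The two helpers above are being made private (leading underscore). For context, here is a self-contained sketch of the pattern they implement: wrap a function with `torch.cuda.amp.custom_fwd` only when the installed torch exposes it. The fallback body shown below is an assumption, since the hunk elides it.

```python
import torch

def _conditional_amp_fwd_decorator(orig_func):  # sketch of the renamed helper
    # Use the AMP-aware wrapper when this torch build provides it.
    if hasattr(torch.cuda.amp, "custom_fwd"):
        return torch.cuda.amp.custom_fwd(orig_func)

    # Fallback for older torch builds: call the function unchanged.
    # (The exact fallback in the patched file is not shown in this hunk.)
    def inner_decorator(*args, **kwargs):
        return orig_func(*args, **kwargs)

    return inner_decorator
```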
@@ -150,7 +150,7 @@ class ActivationCheckpointing(torch.autograd.Function):
    """

    @staticmethod
-    @conditional_amp_fwd_decorator  # type: ignore
+    @_conditional_amp_fwd_decorator  # type: ignore
    def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
        inputs = inputs if isinstance(inputs, tuple) else (inputs,)
@@ -202,7 +202,7 @@ class ActivationCheckpointing(torch.autograd.Function):
        return result[0] if len(result) == 1 else result

    @staticmethod
-    @conditional_amp_bwd_decorator
+    @_conditional_amp_bwd_decorator
    def backward(ctx, *grad_outputs):  # type: ignore
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
@@ -277,83 +277,6 @@ class ActivationCheckpointing(torch.autograd.Function):
        return (None, None) + grads
-
-class ShardSyncLayer(torch.autograd.Function):
-    """
-    The shard sync layer is a synchronization point between model shards.
-    - In the forward pass, it drops parameters in the previous shard and
-      loads parameters for the next shard.
-    - In the backward pass, it does the reverse.
-    It does not change or create any outputs at all, instead it just
-    forwards the input as the output.
-    NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
-    """
-
-    @staticmethod
-    @conditional_amp_fwd_decorator  # type: ignore
-    def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
-        drop_index = index
-        load_index = index + 1
-        max_slices = len(model_slices)
-
-        if drop_index >= 0:
-            # Move shard from device to offload device.
-            logging.info(f"Dropping shard {drop_index}")
-            model_slices[drop_index].forward_drop()
-
-        if load_index < max_slices:
-            # Load shard from offload device to device.
-            logging.info(f"Loading shard{load_index}")
-            model_slices[load_index].forward_load()
-
-        ctx.index = index
-        ctx.model_slices = model_slices
-        ctx.model_instance = model_instance
-
-        return inputs if isinstance(inputs, tuple) else (inputs,)
-
-    @staticmethod
-    @conditional_amp_bwd_decorator
-    def backward(ctx, *grad_outputs):  # type: ignore
-        load_index = ctx.index
-        drop_index = load_index + 1
-        model_slices = ctx.model_slices
-        model_instance = ctx.model_instance
-
-        # TODO(anj-s): Are these redundant in the backward pass?
-        if drop_index == len(model_slices):
-            # Drop the last activation since it is still on the CPU
-            # after the loss.backward() call.
-            model_instance._activations[-1] = tuple([a.cuda() for a in list(model_instance._activations[-1])])
-
-        if drop_index < len(model_slices):
-            # Move shard from device to offload device.
-            logging.info(f"Backward Dropping shard {drop_index}")
-            model_slices[drop_index].backward_drop()
-            model_instance._activations[drop_index] = tuple(
-                [a.cpu() for a in list(model_instance._activations[drop_index])]
-            )
-
-        if load_index >= 0:
-            # Load shard from offload device to device.
-            logging.info(f"Backward Loading shard{load_index}")
-            model_slices[load_index].backward_load()
-            model_instance._activations[load_index] = tuple(
-                [a.cuda() for a in list(model_instance._activations[load_index])]
-            )
-
-        # The returned variables need to mirror the forward inputs
-        # TODO(anj-s): Why do we need to do this?
-        if isinstance(grad_outputs, tuple):
-            return grad_outputs[0], None, None, None
-
-        return grad_outputs, None, None, None
class OffloadModel(nn.Module):
    """Wrapper used offload parts of a model to the CPU.
@@ -388,21 +311,31 @@ class OffloadModel(nn.Module):
    def __init__(
        self,
-        model_cpu: nn.Sequential,
+        model: nn.Sequential,
        device: torch.device,
        offload_device: torch.device = torch.device("cpu"),
-        num_slices: int = 5,
+        num_slices: int = 3,
        checkpoint_activation: bool = False,
        num_microbatches: int = 1,
    ):
        super().__init__()
-        # TODO(anj-s): Add error checks for cuda and sequential model.
+        if not model:
+            raise TypeError("`model` argument to `OffloadModel` cannot be None.")
+        if not device:
+            raise TypeError("`device` argument to `OffloadModel` cannot be None.")
+        if not isinstance(model, nn.Sequential):
+            raise TypeError("`model` argument to `OffloadModel` must be of type `nn.Sequential`.")
+        if not torch.cuda.is_available():
+            raise TypeError("CUDA must be available as one of the compute devices for `OffloadModel`.")
        self.device = device
        self.offload_device = offload_device
        # Slice the model into roughly equivalent sequential shards.
-        splits = _split(model_cpu, num_slices)
+        splits = _split(model, num_slices)
        # List of model shards that will be placed on/off the device.
        self.model_slices: List[nn.Module] = []
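The old TODO is replaced by explicit argument validation. As a quick illustration (hypothetical snippet, not part of the patch; the import path is assumed), passing a plain `nn.Linear` instead of an `nn.Sequential` now fails fast with a `TypeError`:

```python
import torch
from torch import nn
from fairscale.experimental.nn.offload import OffloadModel  # assumed import path

try:
    # Not an nn.Sequential, so the new check raises before any slicing happens.
    OffloadModel(model=nn.Linear(2, 2), device=torch.device("cuda"))
except TypeError as err:
    print(err)  # `model` argument to `OffloadModel` must be of type `nn.Sequential`.
```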
@@ -416,7 +349,7 @@
            )

        # Expose a unified view of the slices
-        self.model = torch.nn.Sequential(*self.model_slices)
+        self._model = torch.nn.Sequential(*self.model_slices)
        # intermediate activations at the slice boundaries.
        self._activations: List[Tuple] = []
@@ -432,27 +365,10 @@
        self._num_microbatches = num_microbatches

    def forward(self, *inputs: Any, **_: Any) -> Any:
-        dummy_input = torch.tensor([], requires_grad=True)
-        if self._checkpoint_activation:
-            return ActivationCheckpointing.apply(*inputs, dummy_input, self)
-
-        self._activations = []
-        for index in range(-1, len(self.model_slices)):
-            if index >= 0:
-                # TODO(anj-s): This might be a redundant call since we have the previous
-                # activation on the device already.
-                self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])])
-                inputs = self._activations[index]
-                inputs = self.model_slices[index](*inputs)
-            # Call the custom autograd hooks (discard/load slices FW and BW)
-            inputs = ShardSyncLayer.apply(inputs, index, self.model_slices, self)
-            self._activations.append(inputs)
-            if index >= 0:
-                self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])])
-
-        # We don't move the last activation/output since the target is present
-        # on the device.
-        # TODO(anj-s): It is now a requirement that the target tensors be placed on the
-        # device.
-        result = self._activations[-1]
-        return result[0] if len(result) == 1 else result
+        # `apply` calls the `forward` function of the `ActivationCheckpointing` class
+        # and the `forward` function calls `inputs` on the first model shard.
+        # Please see https://pytorch.org/docs/stable/autograd.html#function for more details.
+        # We need the second param to be a dummy input to enable the
+        # backward pass to be triggered for integer inputs.
+        return ActivationCheckpointing.apply(*inputs, torch.tensor([], requires_grad=True), self)
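The comment about the dummy input deserves a concrete illustration. The toy `autograd.Function` below (names invented for the demo, not part of fairscale) shows why an empty tensor with `requires_grad=True` is threaded through `apply`: if every real input were an integer tensor, no input would require grad, the output would be detached from autograd, and `backward()` would never run.

```python
import torch

class _Passthrough(torch.autograd.Function):
    """Toy Function used only to demonstrate the dummy-input trick."""

    @staticmethod
    def forward(ctx, inputs, dummy):
        # The dummy tensor is unused; it only keeps the output attached to autograd.
        return inputs.float().sum()

    @staticmethod
    def backward(ctx, grad_output):
        print("backward was reached")
        return None, None

int_input = torch.arange(4)                      # integer tensor, cannot require grad
dummy = torch.tensor([], requires_grad=True)     # same trick as in OffloadModel.forward
_Passthrough.apply(int_input, dummy).backward()  # prints "backward was reached"
# Without the dummy, the output would not require grad and backward() would raise an error.
```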
@@ -32,7 +32,7 @@ def test_single_run():
    device, offload_device = _init()
    model = _get_model()
-    offload_model = OffloadModel(model_cpu=model, device=device, offload_device=offload_device, num_slices=2,)
+    offload_model = OffloadModel(model=model, device=device, offload_device=offload_device, num_slices=2,)
    offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

    input = torch.ones(2, 2).to(device)
@@ -102,7 +102,7 @@ def _train_offload_model(
):
    omodel = copy.deepcopy(model)
    offload_model = OffloadModel(
-        model_cpu=omodel,
+        model=omodel,
        device=device,
        offload_device=offload_device,
        num_slices=2,
...
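Putting the renamed API together, here is a sketch of one training step in the style of `test_single_run` above. The import path, layer sizes, and loss are placeholders, not taken from the patch.

```python
import torch
from torch import nn
from fairscale.experimental.nn.offload import OffloadModel  # assumed import path

device, offload_device = torch.device("cuda"), torch.device("cpu")
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

offload_model = OffloadModel(model=model, device=device, offload_device=offload_device, num_slices=2)
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

inputs = torch.ones(2, 2).to(device)
labels = torch.ones(2, 2).to(device)

optimizer.zero_grad()
loss = nn.MSELoss()(offload_model(inputs), labels)  # forward runs shard by shard
loss.backward()                                      # shards are reloaded in reverse order
optimizer.step()
```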