Unverified commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Set up pre-commit GitHub action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
......@@ -16,9 +16,9 @@ DEBUG = False
def _next_power_of_2_or_max(n: int, max_n: int) -> int:
""" Return the smallest power of 2 greater than or equal to n, with a limit.
"""Return the smallest power of 2 greater than or equal to n, with a limit.
Useful when used in splitting a tensor into chunks with power-of-2 sizes.
Useful when used in splitting a tensor into chunks with power-of-2 sizes.
"""
# special case, just split to 1 element chunks.
if n == 0:
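The helper in this hunk rounds a chunk size up to a power of two, capped at a maximum. A minimal sketch of that behaviour, with an illustrative name and the zero-input handling inferred from the comment above (it is not the committed implementation):

```python
def next_power_of_2_or_max(n: int, max_n: int) -> int:
    # Illustrative sketch only.
    if n == 0:
        return 1  # special case: fall back to 1-element chunks
    p = 1
    while p < n:
        p *= 2
    return min(p, max_n)

assert next_power_of_2_or_max(5, 64) == 8     # rounds 5 up to 8
assert next_power_of_2_or_max(100, 64) == 64  # capped at max_n
```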
......@@ -50,7 +50,7 @@ def _reshape_inputs(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Te
def get_data(
shape: Tuple[Tuple[int, int], Tuple[int, int]], dtype: torch.dtype = torch.float16, device: str = "cuda"
) -> Tuple[torch.Tensor, nn.Parameter, torch.Tensor]:
""" Utility function for getting some tensors for testing and benchmarking."""
"""Utility function for getting some tensors for testing and benchmarking."""
(tokens, d1), (d2, vocabs) = shape
assert d1 == d2
input = torch.rand(tokens, d1, device=device, dtype=dtype).requires_grad_(True)
......@@ -66,7 +66,7 @@ def get_data(
class BaselineSoftmax(nn.Module):
""" Baseline softmax that does an output linear projection and a softmax.
"""Baseline softmax that does an output linear projection and a softmax.
This is intended to be used with an embedding layer with shared weights.
......@@ -94,7 +94,7 @@ class BaselineSoftmax(nn.Module):
self.log_softmax = log_softmax
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
""" Forward function that computes softmax output with the input and target."""
"""Forward function that computes softmax output with the input and target."""
assert isinstance(input, torch.Tensor)
assert isinstance(target, torch.Tensor)
input, target = _reshape_inputs(input, target)
......@@ -111,12 +111,12 @@ class BaselineSoftmax(nn.Module):
class BaselineSoftmaxNllLoss(BaselineSoftmax):
""" Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
"""Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
See BaselineSoftmax above. Constructor is the same. Only difference is in the
forward function.
See BaselineSoftmax above. Constructor is the same. Only difference is in the
forward function.
This class is used for testing and benchmarking.
This class is used for testing and benchmarking.
"""
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 0, log_softmax: bool = True):
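In plain PyTorch terms, the baseline these two classes benchmark is a linear projection onto the vocabulary followed by a (log-)softmax, with the NLL variant adding the loss on top. A hedged sketch of that computation (the function name and shapes are assumptions; this is not the tiled MEVO kernel):

```python
import torch
import torch.nn.functional as F

def baseline_softmax_nll(input: torch.Tensor, proj_weight: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # input: (tokens, d_model), proj_weight: (vocab, d_model), target: (tokens,) token indices
    logits = F.linear(input, proj_weight)      # output projection -> (tokens, vocab)
    log_probs = F.log_softmax(logits, dim=-1)  # softmax in log space
    return F.nll_loss(log_probs, target)       # cross-entropy against the targets
```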
......@@ -177,7 +177,7 @@ class GetMaxFunction(torch.autograd.Function):
def backward(ctx: Any, *args: Any) -> Any:
"""Recompute the forward max and backward grad.
Accumulate the grad to the right split of the full grad.
Accumulate the grad to the right split of the full grad.
"""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG max bwd")
......@@ -248,7 +248,7 @@ class GetSumFunction(torch.autograd.Function):
def backward(ctx: Any, *args: Any) -> Any:
"""Recompute the forward sum and backward grad.
Accumulate the grad to the right split of the full grad.
Accumulate the grad to the right split of the full grad.
"""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG sum bwd")
......@@ -333,9 +333,7 @@ class BackwardTriggerFn(torch.autograd.Function):
"""A backward trigger function."""
@staticmethod
def forward( # type: ignore
ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor
) -> torch.Tensor:
def forward(ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
"""We take a weight tensor and the trigger as inputs and output the weight directly."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG trigger fwd")
......@@ -357,24 +355,24 @@ class BackwardTriggerFn(torch.autograd.Function):
class BackwardTrigger(nn.Module):
"""A backward trigger module.
This module takes a parameter as an input and creates a linked parameter
from a newly created trigger parameter.
This module takes a parameter as an input and creates a linked parameter
from a newly created trigger parameter.
The way to use it in a module's ``__init__'' and ``forward'' functions:
The way to use it in a module's ``__init__'' and ``forward'' functions:
```
def __init__():
...
self.trigger = BackwardTrigger(some_layer.weight)
...
```
def __init__():
...
self.trigger = BackwardTrigger(some_layer.weight)
...
def forward():
w = self.trigger()
... continue to use w ...
```
def forward():
w = self.trigger()
... continue to use w ...
```
As a result, the trigger's backward hook will be called at the end of
the backward for the module that uses this trigger.
As a result, the trigger's backward hook will be called at the end of
the backward for the module that uses this trigger.
"""
def __init__(self, linked_param: torch.Tensor):
......@@ -388,7 +386,7 @@ class BackwardTrigger(nn.Module):
class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
""" Fused fc + softmax + nll_loss in a tiled fashion.
"""Fused fc + softmax + nll_loss in a tiled fashion.
This uses much less memory but is quite a bit slower.
......
......@@ -80,7 +80,11 @@ class ModelShard(nn.Module):
"""
def __init__(
self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, index: int,
self,
cpu_model_shard: nn.Module,
device: torch.device,
offload_device: torch.device,
index: int,
):
super().__init__()
self.model_shard = cpu_model_shard
......@@ -138,22 +142,22 @@ class ModelShard(nn.Module):
class OffloadFunction(torch.autograd.Function):
"""
This Function enables checkpointing of intermediate activations at
shard boundaries by overriding the forward and backward pass of the nn.Module.
This Function enables checkpointing of intermediate activations at
shard boundaries by overriding the forward and backward pass of the nn.Module.
- In the FW pass, it drops parameters in the previous shard and
loads parameters for the next shard. No graph is constructed in the FW pass.
This enables us to offload intermediate activations present at the shard
boundaries.
- In the FW pass, it drops parameters in the previous shard and
loads parameters for the next shard. No graph is constructed in the FW pass.
This enables us to offload intermediate activations present at the shard
boundaries.
- In the BW pass, it does the reverse. We run the forward pass using the
saved intermediate activations and calculate gradients as needed.
The trade-off is latency vs memory when using activation checkpointing.
- In the BW pass, it does the reverse. We run the forward pass using the
saved intermediate activations and calculate gradients as needed.
The trade-off is latency vs memory when using activation checkpointing.
- Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.
- Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.
NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
"""
NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
"""
@staticmethod
@_conditional_amp_fwd_decorator # type: ignore
......@@ -303,14 +307,14 @@ class OffloadFunction(torch.autograd.Function):
class ShardSyncLayer(torch.autograd.Function):
"""
The shard sync layer is a synchronization point between model shards.
- In the forward pass, it drops parameters in the previous shard and
loads parameters for the next shard.
- In the backward pass, it does the reverse.
It does not change or create any outputs at all, instead it just
forwards the input as the output.
NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
"""
The shard sync layer is a synchronization point between model shards.
- In the forward pass, it drops parameters in the previous shard and
loads parameters for the next shard.
- In the backward pass, it does the reverse.
It does not change or create any outputs at all, instead it just
forwards the input as the output.
NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
"""
@staticmethod
@_conditional_amp_fwd_decorator # type: ignore
......@@ -457,17 +461,25 @@ class OffloadModel(nn.Module):
# This is already sharded using the auto shard functionality.
for i, m in enumerate(model):
self.model_slices.append(
ModelShard(cpu_model_shard=m, device=device, offload_device=offload_device, index=i,)
ModelShard(
cpu_model_shard=m,
device=device,
offload_device=offload_device,
index=i,
)
)
else:
# Slice the model into roughly equivalent sequential shards.
splits = _split(model, num_slices)
splits = _split(model, num_slices) # type: ignore
for i, split in enumerate(splits):
# Add one model handling this slice
self.model_slices.append(
ModelShard(
cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i,
cpu_model_shard=nn.Sequential(*split),
device=device,
offload_device=offload_device,
index=i,
)
)
......
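For context on what these ModelShard/OffloadFunction pieces add up to, here is a hedged usage sketch of OffloadModel. The argument names follow the public API described in the fairscale docs, but the exact import path and defaults are assumptions and may differ between releases:

```python
import torch
import torch.nn as nn
from fairscale.experimental.nn.offload import OffloadModel  # assumed import path

model = nn.Sequential(*[nn.Linear(512, 512) for _ in range(8)])
offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),          # where shards actually run
    offload_device=torch.device("cpu"),   # where idle shards (and activations) live
    num_slices=4,                         # split into roughly equal sequential shards
    checkpoint_activation=True,           # use OffloadFunction to checkpoint at shard boundaries
)
loss = offload_model(torch.randn(32, 512, device="cuda")).sum()
loss.backward()
```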
......@@ -72,7 +72,7 @@ def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0)
class StorageState(Enum):
"""
Simple enum to indicate whether the tensor handle is pointing
to data on disk or memory. This is useful for asserting on
to data on disk or memory. This is useful for asserting on
whether the tensor is available for operations or if it needs
to be moved from disk to CPU or device.
"""
......
......@@ -200,6 +200,8 @@ class DynamicLossScaler(object):
def state_dict(self) -> Optional[Dict[str, float]]:
if self.loss_scale is not None:
return {"loss_scale": self.loss_scale}
else:
return None
def load_state_dict(self, state_dict: Dict[str, float]) -> None:
if "loss_scale" in state_dict:
......
......@@ -35,7 +35,8 @@ class TraceForwardEvent(NamedTuple):
@classmethod
def from_dict(cls, serialized: Dict[str, Any]) -> "TraceForwardEvent":
return TraceForwardEvent(
memory_diff=serialized["memory_diff"], memory_activations=serialized["memory_activations"],
memory_diff=serialized["memory_diff"],
memory_activations=serialized["memory_activations"],
)
......@@ -410,7 +411,8 @@ class LayerwiseMemoryTracker:
all_gathered=self._last_all_gather_memory,
cumul_all_gathered=sum(self._cumul_all_gather_memory),
event=TraceForwardEvent(
memory_diff=allocated - self._memory_pre_forward, memory_activations=activations,
memory_diff=allocated - self._memory_pre_forward,
memory_activations=activations,
),
)
)
......@@ -593,7 +595,9 @@ def suggest_checkpoint_location(
# Then map it back to module names
return SuggestedCheckpoints(
max_memory=max_memory, split_modules=[modules[i] for i in reset_indices], all_modules=modules,
max_memory=max_memory,
split_modules=[modules[i] for i in reset_indices],
all_modules=modules,
)
......@@ -609,7 +613,9 @@ def _assert_visualisation_library_installed() -> None:
def compare_memory_traces_in_plot(
memory_traces_by_job: Dict[str, List[LayerMemoryTrace]], figsize: Tuple[int, int] = (16, 20), capture: bool = False,
memory_traces_by_job: Dict[str, List[LayerMemoryTrace]],
figsize: Tuple[int, int] = (16, 20),
capture: bool = False,
) -> Optional[Any]:
"""
Create a plot of the memory allocation over time during the forward/backward
......@@ -684,7 +690,7 @@ class _MemoryGraphCreator:
ax.plot(x, y_forward, x, y_backward, label=job_name)
max_index = np.argmax(allocated_memory)
max_trace = memory_traces[max_index] # type: ignore
max_trace = memory_traces[max_index]
max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
max_phase = "fwd" if max_trace.is_forward else "bwd"
ax.set_ylim([None, max_trace.allocated * 1.1])
......@@ -722,7 +728,7 @@ class _MemoryGraphCreator:
# Adding the name of the layer with max cumulative all_gathered memory
max_index = np.argmax(cumul_gathered_memory)
max_trace = memory_traces[max_index] # type: ignore
max_trace = memory_traces[max_index]
max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
ax.set_ylim([None, max_trace.cumul_all_gathered * 1.1])
x_text, y_text = max(0, max_index * 0.8), max_trace.cumul_all_gathered * 1.04 # type: ignore
......
......@@ -95,7 +95,10 @@ def is_recomputing() -> bool:
return thread_local.is_recomputing
def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False,) -> nn.Module:
def checkpoint_wrapper(
module: nn.Module,
offload_to_cpu: bool = False,
) -> nn.Module:
"""
A friendlier wrapper for performing activation checkpointing.
......
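checkpoint_wrapper is the user-facing entry point for activation checkpointing; its two-argument signature is exactly the one reformatted in the hunk above. A short usage sketch, with the import path taken as an assumption about where fairscale exposed it at the time:

```python
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper  # assumed export location

block = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
block = checkpoint_wrapper(block, offload_to_cpu=True)
# Activations inside `block` are recomputed during the backward pass, and their
# inputs are offloaded to CPU in the meantime, trading compute for memory.
```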
......@@ -448,7 +448,7 @@ class FullyShardedDataParallel(nn.Module):
return self._fsdp_wrapped_module
def append_shared_param(self, p: Parameter) -> None:
""" Add a param that's already owned by another FSDP wrapper.
"""Add a param that's already owned by another FSDP wrapper.
.. warning:: This is experimental!
......@@ -1248,8 +1248,8 @@ class FullyShardedDataParallel(nn.Module):
m._reducer = self._reducer
def _setup_output_hook_list(self) -> None:
""" set up a list to avoid registering pre-backward hooks
incorrectly.
"""set up a list to avoid registering pre-backward hooks
incorrectly.
"""
assert self._is_root, "This should only be called on the root"
self._output_pre_backward_hook_registered = []
......
......@@ -31,7 +31,7 @@ def _trainable(param: torch.Tensor) -> bool:
class ShardedDataParallel(nn.Module):
""" Wrap the model, and reduce the gradients to the right rank during the backward pass.
"""Wrap the model, and reduce the gradients to the right rank during the backward pass.
- the partition is given by the sharded optimizer
- wrap the base model with a model which knows where to reduce each gradient
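The pairing this docstring describes is the usual OSS + ShardedDataParallel setup: the sharded optimizer defines the parameter partition, and the wrapper reduces each gradient to the rank that owns it. A hedged sketch, assuming a process group has already been initialized and following the API shown in the fairscale docs:

```python
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
model = ShardedDDP(model, optimizer)       # the wrapper learns the partition from OSS

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()                            # each grad is reduced to its owning rank
optimizer.step()
```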
......@@ -224,7 +224,10 @@ class ShardedDataParallel(nn.Module):
return self.module(*inputs, **kwargs)
def to( # type: ignore
self, device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False,
self,
device: Optional[torch.device],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> "ShardedDataParallel":
"""
Moves and/or casts the parameters and buffers.
......@@ -273,7 +276,7 @@ class ShardedDataParallel(nn.Module):
self.refresh_trainable()
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
"""If the module trainability has changed, update all the assumptions"""
# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
......@@ -600,8 +603,8 @@ class ShardedDataParallel(nn.Module):
def _consume_work_handles(self) -> None:
"""Consume all the futures which are tied to this optimizer's buckets.
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
while len(self._work_handles) > 0:
work_handle = self._work_handles.popleft()
......@@ -628,7 +631,10 @@ class ShardedDataParallel(nn.Module):
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
tensor=bucket.buffer,
dst=bucket.destination,
group=self._process_group,
async_op=True,
),
callback=None,
)
......
......@@ -37,12 +37,12 @@ if TYPE_CHECKING:
class FlatParameter(nn.Parameter):
""" A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
"""A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
"""
def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) -> "FlatParameter":
""" Make an object using the parent's __new__ function. """
"""Make an object using the parent's __new__ function."""
# An empty or non-list input doesn't make sense.
if not isinstance(params, (list, tuple)) or len(params) == 0:
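The idea behind FlatParameter is simply "one flat storage, many shaped views". The snippet below illustrates that idea with plain tensors; it is not the FlatParameter class itself:

```python
import torch
import torch.nn as nn

params = [nn.Parameter(torch.randn(4, 3)), nn.Parameter(torch.randn(5))]
numels = [p.numel() for p in params]       # bookkeeping, akin to _param_numels
shapes = [p.shape for p in params]         # bookkeeping, akin to _param_shapes

flat = torch.cat([p.detach().reshape(-1) for p in params])        # single flat storage
views = [t.view(s) for t, s in zip(flat.split(numels), shapes)]   # per-parameter views
assert [v.shape for v in views] == shapes
```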
......@@ -66,7 +66,7 @@ class FlatParameter(nn.Parameter):
return super(FlatParameter, cls).__new__(cls, data, requires_grad=requires_grad)
def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
""" Initialize the _param_numels and _param_shapes lists. """
"""Initialize the _param_numels and _param_shapes lists."""
self._param_numels = [p.numel() for p in params]
assert self.numel() <= sum(
self._param_numels
......@@ -78,7 +78,7 @@ class FlatParameter(nn.Parameter):
self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []
def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
""" Return a generator of views that map to the original parameters. """
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
assert self.data.numel() <= sum(
self._param_numels
......@@ -96,14 +96,14 @@ class FlatParameter(nn.Parameter):
return names, self._param_shapes, self._param_numels
def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
""" Use by pickle to set the internal states. """
"""Use by pickle to set the internal states."""
(self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
assert self.numel() <= sum(
self._param_numels
), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
""" Support pickling between ranks. """
"""Support pickling between ranks."""
return (
FlatParameter, # Callable
# Args to the callable above
......@@ -228,15 +228,15 @@ class FlattenParamsWrapper(nn.Module):
@property
def module(self) -> Any:
""" Support fpw.module in case we are immitating DDP, which has .module
property to the underlying module.
"""Support fpw.module in case we are immitating DDP, which has .module
property to the underlying module.
"""
return self._fpw_module
@property
def flat_param(self) -> nn.Parameter:
""" We used to support only a single flat_param. This allows us to
be backward compatible.
"""We used to support only a single flat_param. This allows us to
be backward compatible.
"""
assert len(self.flat_params) == 1, "Incorrect access to flat_param"
return self.flat_params[0]
......@@ -246,7 +246,7 @@ class FlattenParamsWrapper(nn.Module):
) -> Tuple[
List[nn.Parameter], List[Tuple[str, nn.Module, str]], List[Tuple[str, str, nn.Module, str, nn.Module, str]]
]:
""" Build metadata for need-to-be-flatten parameters and returns a list
"""Build metadata for need-to-be-flatten parameters and returns a list
contains the need-to-be-flatten parameters.
This also returns param_infos and shared_param_infos, which
......@@ -287,8 +287,8 @@ class FlattenParamsWrapper(nn.Module):
return chain(*[p._shared_param_infos for p in self.flat_params])
def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
""" Flatten the managed parameters and replaced the original
attributes with views to the flat params.
"""Flatten the managed parameters and replaced the original
attributes with views to the flat params.
"""
assert not self.is_flattened
self.is_flattened = True
......@@ -309,8 +309,8 @@ class FlattenParamsWrapper(nn.Module):
self._unflatten_params_as_views()
def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
""" Undo flattening and create separate parameters from the already flattened
self.flat_param or a user supplied external data.
"""Undo flattening and create separate parameters from the already flattened
self.flat_param or a user supplied external data.
"""
assert self.is_flattened or external_data is not None
self.is_flattened = False
......@@ -336,8 +336,8 @@ class FlattenParamsWrapper(nn.Module):
self.flat_params = []
def _unflatten_params_as_views(self) -> None:
""" Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""
assert self.is_flattened
ps = self.get_param_views()
......@@ -459,7 +459,7 @@ class FlattenParamsWrapper(nn.Module):
return self.module(*inputs, **kwinputs)
def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
""" Used to get a generator over all views from a list of external data list. """
"""Used to get a generator over all views from a list of external data list."""
params = self.flat_params
if external_data_list is None:
external_data_list = [None] * len(params)
......
......@@ -120,19 +120,17 @@ class GradBucket(Bucket):
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
""" Reset the counter of the parameter grads which have been checked in
"""
"""Reset the counter of the parameter grads which have been checked in"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
""" Have all the expected gradient check-in happened ?"""
"""Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
"""Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?"""
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
......
......@@ -117,7 +117,11 @@ class Top2Gate(torch.nn.Module):
wg: torch.nn.Linear
def __init__(self, model_dim: int, num_experts: int,) -> None:
def __init__(
self,
model_dim: int,
num_experts: int,
) -> None:
super().__init__()
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
......
......@@ -222,7 +222,10 @@ class AsyncPipe(Module):
)
def instantiate_partition(
self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: List[int],
group: torch.distributed.ProcessGroup,
) -> List[ModuleWrapper]:
layers: NamedModules = OrderedDict()
......
......@@ -64,7 +64,13 @@ class AsyncPipeline:
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
rank = self.group.rank()
event_loop = AsyncEventLoop(self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,)
event_loop = AsyncEventLoop(
self.partitions,
self.group,
self.transport,
self.training,
self.checkpoint_stop,
)
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
......
......@@ -169,7 +169,10 @@ class AsyncRecvOperator(torch.autograd.Function):
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
def backward(
ctx,
*grad: Tensor,
) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
body = AsyncMessageBody(
......@@ -177,7 +180,11 @@ class AsyncRecvOperator(torch.autograd.Function):
)
ctx.transport.send_message(
PipeMessage(
this_rank, ranks[ctx.args.source.stage], queue_name=ctx.queue_name, args=body, tensors=tuple(grad),
this_rank,
ranks[ctx.args.source.stage],
queue_name=ctx.queue_name,
args=body,
tensors=tuple(grad),
),
sync=True,
)
......@@ -242,7 +249,12 @@ class AsyncEventLoop:
to the next stage in the pipeline if needed."""
task = create_task(
self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
self.checkpoint_stop,
batch.index,
self.group.rank(),
batch,
partition.module,
skip_trackers,
)
result = task.compute()
task.finalize(result)
......@@ -267,7 +279,7 @@ class AsyncEventLoop:
# All batches saved in `activations` are generated by AutogradWithoutActivations,
# so we store the gradients in `grad_from_pipeline` so it will be used
# during the backward pass
batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore
batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors)
batch.tensor.backward(retain_graph=True)
def run_invocations_on_batch(
......
......@@ -37,7 +37,10 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
def layerwise_sandbox(
module: nn.Sequential,
device: torch.device,
) -> Generator[nn.Module, None, None]:
"""Copies layers for ease to profile. It doesn't modify the given
module.
"""
......@@ -54,7 +57,12 @@ def detach(batch: Batch) -> None:
batch[i] = x.detach().requires_grad_(x.requires_grad)
def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]:
def profile_times(
module: nn.Sequential,
sample: TensorOrTensors,
timeout: float,
device: torch.device,
) -> List[int]:
"""Profiles elapsed times per layer."""
if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient")
......@@ -95,7 +103,11 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
def profile_sizes(
module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device,
module: nn.Sequential,
input: TensorOrTensors,
chunks: int,
param_scale: float,
device: torch.device,
) -> List[int]:
"""Profiles CUDA memory usage per layer."""
if device.type != "cuda":
......
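These profilers back the public balancing helpers. A hedged usage sketch of balance_by_time, whose signature here follows the torchgpipe-derived API fairscale shipped at the time (treat the import path and keyword names as assumptions):

```python
import torch
import torch.nn as nn
from fairscale.nn.pipe.balance import balance_by_time  # assumed import path

model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(10)])
sample = torch.rand(32, 256)

# Ask for a 2-way split whose partitions take roughly equal forward/backward time.
balance = balance_by_time(2, model, sample, device=torch.device("cpu"))
print(balance)  # e.g. [5, 5]
```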
......@@ -197,7 +197,10 @@ class Context:
pass
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
def save_rng_states(
device: torch.device,
rng_states: Deque[RNGStates],
) -> None:
""":meth:`Checkpoint.forward` captures the current PyTorch's random number
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
......@@ -217,7 +220,10 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
def restore_rng_states(
device: torch.device,
rng_states: Deque[RNGStates],
) -> Generator[None, None, None]:
""":meth:`Recompute.backward` restores the random number generator states
captured by :func:`save_rng_states` within its context.
......@@ -264,7 +270,10 @@ class Checkpoint(torch.autograd.Function):
return output
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
def backward(
ctx: Context,
*grad_output: Tensor,
) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
output, input_leaf = ctx.recomputed.pop()
if isinstance(output, tuple):
......
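The save/restore pair exists so that the recompute in Recompute.backward sees exactly the same random numbers (e.g. dropout masks) as the original forward. The core pattern, shown here with the plain CPU RNG only:

```python
import torch

state = torch.get_rng_state()      # capture, as save_rng_states does at forward time
first = torch.rand(3)

torch.set_rng_state(state)         # restore, as restore_rng_states does before recompute
second = torch.rand(3)

assert torch.equal(first, second)  # the recompute reproduces the same randomness
```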
......@@ -45,7 +45,12 @@ class Copy(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
def forward(
ctx: Context,
prev_stream: AbstractStream,
next_stream: AbstractStream,
*input: Tensor,
) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
......@@ -66,7 +71,10 @@ class Copy(torch.autograd.Function):
return tuple(output)
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
def backward(
ctx: Context,
*grad_output: Tensor,
) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
......@@ -98,7 +106,12 @@ class Wait(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
def forward(
ctx: Context,
prev_stream: AbstractStream,
next_stream: AbstractStream,
*input: Tensor,
) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
......@@ -107,7 +120,10 @@ class Wait(torch.autograd.Function):
return tuple(x.detach() for x in input)
@staticmethod
def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
def backward(
ctx: Context,
*grad_input: Tensor,
) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
......
......@@ -106,7 +106,9 @@ class BalanceError(ValueError):
def split_module(
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
module: nn.Sequential,
balance: Iterable[int],
devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
"""Splits a module into multiple partitions.
......@@ -350,12 +352,12 @@ class Pipe(Module):
raise MOVING_DENIED
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
""" Deny these usages:
- to(device[, dtype, non_blocking])
- to(tensor[, non_blocking])
"""Deny these usages:
- to(device[, dtype, non_blocking])
- to(tensor[, non_blocking])
But allow this:
- to(dtype[, non_blocking])"""
But allow this:
- to(dtype[, non_blocking])"""
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
......
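split_module feeds the Pipe constructor, which places each partition on its own device and streams micro-batches through them; the ``to`` override above exists because moving a Pipe wholesale would break that placement. A hedged usage sketch, with illustrative balance and chunks values:

```python
import torch.nn as nn
from fairscale.nn import Pipe  # assumed export location

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())

# Two partitions of two layers each, fed with 4 micro-batches per input batch.
pipe = Pipe(model, balance=[2, 2], chunks=4)
```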
......@@ -130,7 +130,10 @@ class Pipeline:
self.compute(batches, schedule, skip_trackers)
def fence(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
self,
batches: List[Batch],
schedule: List[Tuple[int, int]],
skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Copies micro-batches after computation for the previous
micro-batches.
......@@ -155,7 +158,10 @@ class Pipeline:
copy(batches[i], prev_stream, next_stream)
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
self,
batches: List[Batch],
schedule: List[Tuple[int, int]],
skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Runs tasks with synchronization to copy streams."""
partitions = self.partitions
......
......@@ -39,7 +39,11 @@ class SkipLayout:
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_src_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
def __init__(
self,
num_partitions: int,
skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],
) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes
......