Unverified Commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Setup pre-commit github action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
@@ -16,9 +16,9 @@ DEBUG = False
def _next_power_of_2_or_max(n: int, max_n: int) -> int:
-    """ Return the smallest power of 2 greater than or equal to n, with a limit.
+    """Return the smallest power of 2 greater than or equal to n, with a limit.
    Useful when used in splitting a tensor into chunks with power-of-2 sizes.
    """
    # special case, just split to 1 element chunks.
    if n == 0:

@@ -50,7 +50,7 @@ def _reshape_inputs(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Te
def get_data(
    shape: Tuple[Tuple[int, int], Tuple[int, int]], dtype: torch.dtype = torch.float16, device: str = "cuda"
) -> Tuple[torch.Tensor, nn.Parameter, torch.Tensor]:
-    """ Utility function for getting some tensors for testing and benchmarking."""
+    """Utility function for getting some tensors for testing and benchmarking."""
    (tokens, d1), (d2, vocabs) = shape
    assert d1 == d2
    input = torch.rand(tokens, d1, device=device, dtype=dtype).requires_grad_(True)

@@ -66,7 +66,7 @@ def get_data(
class BaselineSoftmax(nn.Module):
-    """ Baseline softmax that does an output linear projection and a softmax.
+    """Baseline softmax that does an output linear projection and a softmax.
    This is intended to be used with an embedding layer with shared weights.

@@ -94,7 +94,7 @@ class BaselineSoftmax(nn.Module):
        self.log_softmax = log_softmax
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:  # type: ignore
-        """ Forward function that computes softmax output with the input and target."""
+        """Forward function that computes softmax output with the input and target."""
        assert isinstance(input, torch.Tensor)
        assert isinstance(target, torch.Tensor)
        input, target = _reshape_inputs(input, target)

@@ -111,12 +111,12 @@ class BaselineSoftmax(nn.Module):
class BaselineSoftmaxNllLoss(BaselineSoftmax):
-    """ Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
+    """Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
    See BaselineSoftmax above. Constructor is the same. Only difference is in the
    forward function.
    This class is used for testing and benchmarking.
    """
    def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 0, log_softmax: bool = True):

@@ -177,7 +177,7 @@ class GetMaxFunction(torch.autograd.Function):
    def backward(ctx: Any, *args: Any) -> Any:
        """Recompute the forward max and backward grad.
        Accumulate the grad to the right split of the full grad.
        """
        if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
            print("DEBUG max bwd")

@@ -248,7 +248,7 @@ class GetSumFunction(torch.autograd.Function):
    def backward(ctx: Any, *args: Any) -> Any:
        """Recompute the forward sum and backward grad.
        Accumulate the grad to the right split of the full grad.
        """
        if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
            print("DEBUG sum bwd")

@@ -333,9 +333,7 @@ class BackwardTriggerFn(torch.autograd.Function):
    """A backward trigger function."""
    @staticmethod
-    def forward(  # type: ignore
-        ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
        """We take a weight tensor and the trigger as inputs and output the weight directly."""
        if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
            print("DEBUG trigger fwd")

@@ -357,24 +355,24 @@ class BackwardTriggerFn(torch.autograd.Function):
class BackwardTrigger(nn.Module):
    """A backward trigger module.
    This module takes a parameter as an input and create a linked parameter
    from a newly created trigger parameter.
    The way to use it in a module's ``__init__'' and ``forward'' functions:
    ```
    def __init__():
        ...
        self.trigger = BackwardTrigger(some_layer.weight)
        ...
    def forward():
        w = self.trigger()
        ... continue to use w ...
    ```
    As a resule, the trigger's backward hook will be called at the end of
    the backward for the module that uses this trigger.
    """
    def __init__(self, linked_param: torch.Tensor):

@@ -388,7 +386,7 @@ class BackwardTrigger(nn.Module):
class MemoryEfficientVocabOutput(nn.Module):  # AKA. MEVO
-    """ Fused fc + softmax + nll_loss in a tiled fashion.
+    """Fused fc + softmax + nll_loss in a tiled fashion.
    This uses much less memory but is quite a bit slower.
......
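The BaselineSoftmax docstrings touched above describe the untiled baseline that MEVO replaces: an output projection with a (possibly shared) weight, followed by a softmax and an NLL loss. A minimal sketch of that baseline in plain PyTorch, with illustrative shapes that are not taken from the diff:

```
import torch
import torch.nn.functional as F

# Illustrative sizes (assumed): 8 tokens, hidden dim 4, vocab of 16.
tokens, d_model, vocab = 8, 4, 16
input = torch.randn(tokens, d_model, requires_grad=True)
proj_weight = torch.nn.Parameter(torch.randn(vocab, d_model))  # typically tied to the embedding
target = torch.randint(0, vocab, (tokens,))

logits = F.linear(input, proj_weight)      # output projection: (tokens, vocab)
log_probs = F.log_softmax(logits, dim=-1)  # softmax in log space
loss = F.nll_loss(log_probs, target)       # log_softmax + nll_loss == cross-entropy
loss.backward()
```

MEVO computes the same quantity but tiles the projection over vocab chunks so the full (tokens, vocab) logits never materialize at once.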
@@ -80,7 +80,11 @@ class ModelShard(nn.Module):
    """
    def __init__(
-        self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, index: int,
+        self,
+        cpu_model_shard: nn.Module,
+        device: torch.device,
+        offload_device: torch.device,
+        index: int,
    ):
        super().__init__()
        self.model_shard = cpu_model_shard

@@ -138,22 +142,22 @@ class ModelShard(nn.Module):
class OffloadFunction(torch.autograd.Function):
    """
    This Function enables checkpointing of intermediate activations at
    shard boundaries by overriding the forward and backward pass of the nn.Module.
    - In the FW pass, it drops parameters in the previous shard and
    loads parameters for the next shard. No graph is constructed in the FW pass.
    This enables us to offload intermediate activations present at the shard
    boundaries.
    - In the BW pass, it does the reverse. We run the forward pass using the
    saved intermediate activations and calculate gradients as needed.
    The trade-off is latency vs memory when using activation checkpointing.
    - Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.
    NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
    """
    @staticmethod
    @_conditional_amp_fwd_decorator  # type: ignore

@@ -303,14 +307,14 @@ class OffloadFunction(torch.autograd.Function):
class ShardSyncLayer(torch.autograd.Function):
    """
    The shard sync layer is a synchronization point between model shards.
    - In the forward pass, it drops parameters in the previous shard and
    loads parameters for the next shard.
    - In the backward pass, it does the reverse.
    It does not change or create any outputs at all, instead it just
    forwards the input as the output.
    NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
    """
    @staticmethod
    @_conditional_amp_fwd_decorator  # type: ignore

@@ -457,17 +461,25 @@ class OffloadModel(nn.Module):
            # This is already sharded using the auto shard functinality.
            for i, m in enumerate(model):
                self.model_slices.append(
-                    ModelShard(cpu_model_shard=m, device=device, offload_device=offload_device, index=i,)
+                    ModelShard(
+                        cpu_model_shard=m,
+                        device=device,
+                        offload_device=offload_device,
+                        index=i,
+                    )
                )
        else:
            # Slice the model into roughly equivalent sequential shards.
-            splits = _split(model, num_slices)
+            splits = _split(model, num_slices)  # type: ignore
            for i, split in enumerate(splits):
                # Add one model handling this slice
                self.model_slices.append(
                    ModelShard(
-                        cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i,
+                        cpu_model_shard=nn.Sequential(*split),
+                        device=device,
+                        offload_device=offload_device,
+                        index=i,
                    )
                )
......
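The ModelShard and ShardSyncLayer docstrings above describe dropping the previous shard's parameters and loading the next shard's around each boundary. A conceptual sketch of that forward-pass movement with two hypothetical shards (the real OffloadModel also handles the backward pass, activation offloading, and checkpointing):

```
import torch
import torch.nn as nn

# Hypothetical shards that live on the offload device; only one is on the GPU at a time.
shards = [nn.Linear(32, 32), nn.Linear(32, 32)]
device, offload_device = torch.device("cuda"), torch.device("cpu")  # requires a CUDA device

def offloaded_forward(x: torch.Tensor) -> torch.Tensor:
    for shard in shards:
        shard.to(device)          # load parameters for the next shard
        x = shard(x.to(device))
        shard.to(offload_device)  # drop them back to the offload device
    return x
```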
@@ -72,7 +72,7 @@ def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0)
class StorageState(Enum):
    """
    Simple enum to indicate whether the tensor handle is pointing
    to data on disk or memory. This is useful for asserting on
    whether the tensor is available for operations or if it needs
    to be moved from disk to CPU or device.
    """
......
@@ -200,6 +200,8 @@ class DynamicLossScaler(object):
    def state_dict(self) -> Optional[Dict[str, float]]:
        if self.loss_scale is not None:
            return {"loss_scale": self.loss_scale}
+        else:
+            return None
    def load_state_dict(self, state_dict: Dict[str, float]) -> None:
        if "loss_scale" in state_dict:
......
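The hunk above makes `DynamicLossScaler.state_dict` return `None` explicitly when no scale is set. For context, dynamic loss scaling multiplies the loss by a scale before backward and backs the scale off when non-finite gradients appear; a toy sketch of that idea, not fairscale's implementation:

```
import torch

scale, backoff, growth_interval, good_steps = 2.0 ** 15, 0.5, 2000, 0

def scaled_backward(loss: torch.Tensor, params) -> bool:
    """Backward with loss scaling; returns True if the optimizer step can be applied."""
    global scale, good_steps
    (loss * scale).backward()
    overflow = any(p.grad is not None and not torch.isfinite(p.grad).all() for p in params)
    if overflow:
        scale *= backoff  # shrink the scale and skip this step
        good_steps = 0
    else:
        for p in params:
            if p.grad is not None:
                p.grad.div_(scale)  # unscale before the optimizer step
        good_steps += 1
        if good_steps % growth_interval == 0:
            scale *= 2.0  # cautiously grow the scale back
    return not overflow
```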
@@ -35,7 +35,8 @@ class TraceForwardEvent(NamedTuple):
    @classmethod
    def from_dict(cls, serialized: Dict[str, Any]) -> "TraceForwardEvent":
        return TraceForwardEvent(
-            memory_diff=serialized["memory_diff"], memory_activations=serialized["memory_activations"],
+            memory_diff=serialized["memory_diff"],
+            memory_activations=serialized["memory_activations"],
        )

@@ -410,7 +411,8 @@ class LayerwiseMemoryTracker:
                all_gathered=self._last_all_gather_memory,
                cumul_all_gathered=sum(self._cumul_all_gather_memory),
                event=TraceForwardEvent(
-                    memory_diff=allocated - self._memory_pre_forward, memory_activations=activations,
+                    memory_diff=allocated - self._memory_pre_forward,
+                    memory_activations=activations,
                ),
            )
        )

@@ -593,7 +595,9 @@ def suggest_checkpoint_location(
    # Then map it back to module names
    return SuggestedCheckpoints(
-        max_memory=max_memory, split_modules=[modules[i] for i in reset_indices], all_modules=modules,
+        max_memory=max_memory,
+        split_modules=[modules[i] for i in reset_indices],
+        all_modules=modules,
    )

@@ -609,7 +613,9 @@ def _assert_visualisation_library_installed() -> None:
def compare_memory_traces_in_plot(
-    memory_traces_by_job: Dict[str, List[LayerMemoryTrace]], figsize: Tuple[int, int] = (16, 20), capture: bool = False,
+    memory_traces_by_job: Dict[str, List[LayerMemoryTrace]],
+    figsize: Tuple[int, int] = (16, 20),
+    capture: bool = False,
) -> Optional[Any]:
    """
    Create a plot of the memory allocation over time during the forward/backward

@@ -684,7 +690,7 @@ class _MemoryGraphCreator:
        ax.plot(x, y_forward, x, y_backward, label=job_name)
        max_index = np.argmax(allocated_memory)
-        max_trace = memory_traces[max_index]  # type: ignore
+        max_trace = memory_traces[max_index]
        max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
        max_phase = "fwd" if max_trace.is_forward else "bwd"
        ax.set_ylim([None, max_trace.allocated * 1.1])

@@ -722,7 +728,7 @@ class _MemoryGraphCreator:
        # Adding the name of the layer with max cumulative all_gathered memory
        max_index = np.argmax(cumul_gathered_memory)
-        max_trace = memory_traces[max_index]  # type: ignore
+        max_trace = memory_traces[max_index]
        max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
        ax.set_ylim([None, max_trace.cumul_all_gathered * 1.1])
        x_text, y_text = max(0, max_index * 0.8), max_trace.cumul_all_gathered * 1.04  # type: ignore
......
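The LayerwiseMemoryTracker hunks above build `TraceForwardEvent`s from a per-layer `memory_diff` and activation size. A stripped-down sketch of capturing such per-layer deltas with forward hooks and `torch.cuda.memory_allocated` (an assumed simplification; the real tracker records far more, including all-gather traffic):

```
import torch
import torch.nn as nn

traces = []

def track(model: nn.Module) -> None:
    # Requires a CUDA device for torch.cuda.memory_allocated().
    def pre_hook(mod, inputs):
        mod._mem_pre_forward = torch.cuda.memory_allocated()

    def post_hook(mod, inputs, output):
        allocated = torch.cuda.memory_allocated()
        traces.append({"module": type(mod).__name__,
                       "memory_diff": allocated - mod._mem_pre_forward})

    for m in model.modules():
        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(post_hook)
```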
@@ -95,7 +95,10 @@ def is_recomputing() -> bool:
    return thread_local.is_recomputing
-def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False,) -> nn.Module:
+def checkpoint_wrapper(
+    module: nn.Module,
+    offload_to_cpu: bool = False,
+) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.
......
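`checkpoint_wrapper`, whose signature is reflowed above, wraps a module for activation checkpointing. The same trade (recompute activations during backward instead of storing them) can be sketched with stock `torch.utils.checkpoint`:

```
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
x = torch.randn(4, 128, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed during backward.
y = checkpoint(block, x)
y.sum().backward()
```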
@@ -448,7 +448,7 @@ class FullyShardedDataParallel(nn.Module):
        return self._fsdp_wrapped_module
    def append_shared_param(self, p: Parameter) -> None:
-        """ Add a param that's already owned by another FSDP wrapper.
+        """Add a param that's already owned by another FSDP wrapper.
        .. warning:: This is experimental!

@@ -1248,8 +1248,8 @@ class FullyShardedDataParallel(nn.Module):
            m._reducer = self._reducer
    def _setup_output_hook_list(self) -> None:
-        """ set up a list to avoid registering pre-backward hooks
+        """set up a list to avoid registering pre-backward hooks
        incorrectly.
        """
        assert self._is_root, "This should only be called on the root"
        self._output_pre_backward_hook_registered = []
......
@@ -31,7 +31,7 @@ def _trainable(param: torch.Tensor) -> bool:
class ShardedDataParallel(nn.Module):
-    """ Wrap the model, and reduce the gradients to the right rank during the backward pass.
+    """Wrap the model, and reduce the gradients to the right rank during the backward pass.
    - the partition is given by the sharded optimizer
    - wrap the base model with a model which knows where to reduce each gradient

@@ -224,7 +224,10 @@ class ShardedDataParallel(nn.Module):
        return self.module(*inputs, **kwargs)
    def to(  # type: ignore
-        self, device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False,
+        self,
+        device: Optional[torch.device],
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
    ) -> "ShardedDataParallel":
        """
        Moves and/or casts the parameters and buffers.

@@ -273,7 +276,7 @@ class ShardedDataParallel(nn.Module):
        self.refresh_trainable()
    def refresh_trainable(self) -> None:
-        """ If the module trainability has changed, update all the assumptions """
+        """If the module trainability has changed, update all the assumptions"""
        # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
        if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):

@@ -600,8 +603,8 @@ class ShardedDataParallel(nn.Module):
    def _consume_work_handles(self) -> None:
        """Consume all the futures which are tied to this optimizer's buckets.
        We start from the first/older ones, since they are the most likely to be ready and non-blocking
        """
        while len(self._work_handles) > 0:
            work_handle = self._work_handles.popleft()

@@ -628,7 +631,10 @@ class ShardedDataParallel(nn.Module):
            self._work_handles.append(
                Workhandle(
                    handle=dist.reduce(
-                        tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
+                        tensor=bucket.buffer,
+                        dst=bucket.destination,
+                        group=self._process_group,
+                        async_op=True,
                    ),
                    callback=None,
                )
......
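The last hunk above issues an asynchronous `dist.reduce` per gradient bucket and stores the returned work handle; `_consume_work_handles` later waits on them oldest-first. A minimal sketch of that pattern (process-group initialization is assumed to happen elsewhere):

```
from collections import deque

import torch
import torch.distributed as dist

def reduce_bucket_async(buffer: torch.Tensor, destination_rank: int):
    # Returns immediately with a work handle; the reduction proceeds in the background.
    return dist.reduce(tensor=buffer, dst=destination_rank, async_op=True)

def consume_work_handles(work_handles: deque) -> None:
    # Wait on the oldest handles first, mirroring the docstring above.
    while work_handles:
        work_handles.popleft().wait()
```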
@@ -37,12 +37,12 @@ if TYPE_CHECKING:
class FlatParameter(nn.Parameter):
-    """ A parameter that is initialized from a list of parameters and can be
+    """A parameter that is initialized from a list of parameters and can be
    turned into a list of views as needed.
    """
    def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) -> "FlatParameter":
-        """ Make an object using the parent's __new__ function. """
+        """Make an object using the parent's __new__ function."""
        # A empty of non-list input doesn't make sense.
        if not isinstance(params, (list, tuple)) or len(params) == 0:

@@ -66,7 +66,7 @@ class FlatParameter(nn.Parameter):
        return super(FlatParameter, cls).__new__(cls, data, requires_grad=requires_grad)
    def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
-        """ Initialize the _param_numels and _param_shapes lists. """
+        """Initialize the _param_numels and _param_shapes lists."""
        self._param_numels = [p.numel() for p in params]
        assert self.numel() <= sum(
            self._param_numels

@@ -78,7 +78,7 @@ class FlatParameter(nn.Parameter):
        self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []
    def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
-        """ Return a generator of views that map to the original parameters. """
+        """Return a generator of views that map to the original parameters."""
        # Note, self.data could be sharded, so its numel is <= to the sum.
        assert self.data.numel() <= sum(
            self._param_numels

@@ -96,14 +96,14 @@ class FlatParameter(nn.Parameter):
        return names, self._param_shapes, self._param_numels
    def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
-        """ Use by pickle to set the internal states. """
+        """Use by pickle to set the internal states."""
        (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
        assert self.numel() <= sum(
            self._param_numels
        ), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
    def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
-        """ Support pickling between ranks. """
+        """Support pickling between ranks."""
        return (
            FlatParameter,  # Callable
            # Args to the callable above

@@ -228,15 +228,15 @@ class FlattenParamsWrapper(nn.Module):
    @property
    def module(self) -> Any:
-        """ Support fpw.module in case we are immitating DDP, which has .module
+        """Support fpw.module in case we are immitating DDP, which has .module
        property to the underlying module.
        """
        return self._fpw_module
    @property
    def flat_param(self) -> nn.Parameter:
-        """ We used to support only a single flat_param. This allows us to
+        """We used to support only a single flat_param. This allows us to
        be backward compatible.
        """
        assert len(self.flat_params) == 1, "Incorrect access to flat_param"
        return self.flat_params[0]

@@ -246,7 +246,7 @@ class FlattenParamsWrapper(nn.Module):
    ) -> Tuple[
        List[nn.Parameter], List[Tuple[str, nn.Module, str]], List[Tuple[str, str, nn.Module, str, nn.Module, str]]
    ]:
-        """ Build metadata for need-to-be-flatten parameters and returns a list
+        """Build metadata for need-to-be-flatten parameters and returns a list
        contains the need-to-be-flatten parameters.
        This also returns param_infos and shared_param_infos, which

@@ -287,8 +287,8 @@ class FlattenParamsWrapper(nn.Module):
        return chain(*[p._shared_param_infos for p in self.flat_params])
    def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
-        """ Flatten the managed parameters and replaced the original
+        """Flatten the managed parameters and replaced the original
        attributes with views to the flat params.
        """
        assert not self.is_flattened
        self.is_flattened = True

@@ -309,8 +309,8 @@ class FlattenParamsWrapper(nn.Module):
            self._unflatten_params_as_views()
    def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
-        """ Undo flattening and create separate parameters from the already flattened
+        """Undo flattening and create separate parameters from the already flattened
        self.flat_param or a user supplied external data.
        """
        assert self.is_flattened or external_data is not None
        self.is_flattened = False

@@ -336,8 +336,8 @@ class FlattenParamsWrapper(nn.Module):
        self.flat_params = []
    def _unflatten_params_as_views(self) -> None:
-        """ Unlike ``_unflatten_params``, this function unflatten into views and keep
+        """Unlike ``_unflatten_params``, this function unflatten into views and keep
        self.flat_param unchanged.
        """
        assert self.is_flattened
        ps = self.get_param_views()

@@ -459,7 +459,7 @@ class FlattenParamsWrapper(nn.Module):
        return self.module(*inputs, **kwinputs)
    def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
-        """ Used to get a generator over all views from a list of external data list. """
+        """Used to get a generator over all views from a list of external data list."""
        params = self.flat_params
        if external_data_list is None:
            external_data_list = [None] * len(params)
......
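`FlatParameter` above concatenates a list of parameters into one flat tensor and later hands back views with the original shapes. The underlying tensor mechanics are just `cat`, `split`, and `view`; a hedged sketch, not the fairscale class:

```
import torch
import torch.nn as nn

params = [nn.Parameter(torch.randn(3, 4)), nn.Parameter(torch.randn(5))]
numels = [p.numel() for p in params]
shapes = [p.shape for p in params]

# Flatten: one contiguous 1-D tensor holding every parameter.
flat = torch.cat([p.detach().reshape(-1) for p in params])

# Unflatten: views that map back to the original shapes.
views = [chunk.view(shape) for chunk, shape in zip(flat.split(numels), shapes)]
assert all(v.shape == p.shape for v, p in zip(views, params))
```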
@@ -120,19 +120,17 @@ class GradBucket(Bucket):
        self.callback: Optional[Callable[[Any], None]] = None
    def reset_checked_in(self) -> None:
-        """ Reset the counter of the parameter grads which have been checked in
-        """
+        """Reset the counter of the parameter grads which have been checked in"""
        self.params_checked_in = 0
        self.sent = False
    @property
    def all_checked_in(self) -> bool:
-        """ Have all the expected gradient check-in happened ?"""
+        """Have all the expected gradient check-in happened ?"""
        return len(self._params) == self.params_checked_in
    def can_add_grad_view(self, param: torch.Tensor) -> bool:
-        """ Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
-        """
+        """Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?"""
        return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
    def to(  # type: ignore
......
@@ -117,7 +117,11 @@ class Top2Gate(torch.nn.Module):
    wg: torch.nn.Linear
-    def __init__(self, model_dim: int, num_experts: int,) -> None:
+    def __init__(
+        self,
+        model_dim: int,
+        num_experts: int,
+    ) -> None:
        super().__init__()
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
......
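`Top2Gate` above owns a single bias-free linear layer `wg` that scores each token against each expert. A sketch of selecting the top-2 experts from those scores (the real gate also computes dispatch masks, capacity, and an auxiliary load-balancing loss):

```
import torch

model_dim, num_experts, tokens = 16, 4, 8  # illustrative sizes
wg = torch.nn.Linear(model_dim, num_experts, bias=False)

x = torch.randn(tokens, model_dim)
logits = wg(x)                         # (tokens, num_experts)
gates = torch.softmax(logits, dim=-1)
top2_vals, top2_experts = torch.topk(gates, k=2, dim=-1)  # two experts per token
```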
@@ -222,7 +222,10 @@ class AsyncPipe(Module):
        )
    def instantiate_partition(
-        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
+        self,
+        module: Union[nn.Sequential, List[LazyModule]],
+        balance: List[int],
+        group: torch.distributed.ProcessGroup,
    ) -> List[ModuleWrapper]:
        layers: NamedModules = OrderedDict()
......
@@ -64,7 +64,13 @@ class AsyncPipeline:
        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
        rank = self.group.rank()
-        event_loop = AsyncEventLoop(self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,)
+        event_loop = AsyncEventLoop(
+            self.partitions,
+            self.group,
+            self.transport,
+            self.training,
+            self.checkpoint_stop,
+        )
        if rank == 0 and not self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event head")
            event_loop.event_loop_head(batches, skip_trackers, event)
......
@@ -169,7 +169,10 @@ class AsyncRecvOperator(torch.autograd.Function):
    @staticmethod
    # type: ignore
-    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
+    def backward(
+        ctx,
+        *grad: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()
        body = AsyncMessageBody(

@@ -177,7 +180,11 @@ class AsyncRecvOperator(torch.autograd.Function):
        )
        ctx.transport.send_message(
            PipeMessage(
-                this_rank, ranks[ctx.args.source.stage], queue_name=ctx.queue_name, args=body, tensors=tuple(grad),
+                this_rank,
+                ranks[ctx.args.source.stage],
+                queue_name=ctx.queue_name,
+                args=body,
+                tensors=tuple(grad),
            ),
            sync=True,
        )

@@ -242,7 +249,12 @@ class AsyncEventLoop:
        to the next stage in the pipeline if needed."""
        task = create_task(
-            self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
+            self.checkpoint_stop,
+            batch.index,
+            self.group.rank(),
+            batch,
+            partition.module,
+            skip_trackers,
        )
        result = task.compute()
        task.finalize(result)

@@ -267,7 +279,7 @@ class AsyncEventLoop:
        # All batches saved in `activations` are generated by AutogradWithoutActivations,
        # so we store the gradients in `grad_from_pipeline` so it will be used
        # during the backward pass
-        batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors)  # type: ignore
+        batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors)
        batch.tensor.backward(retain_graph=True)
    def run_invocations_on_batch(
......
@@ -37,7 +37,10 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
-def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
+def layerwise_sandbox(
+    module: nn.Sequential,
+    device: torch.device,
+) -> Generator[nn.Module, None, None]:
    """Copies layers for ease to profile. It doesn't modify the given
    module.
    """

@@ -54,7 +57,12 @@ def detach(batch: Batch) -> None:
        batch[i] = x.detach().requires_grad_(x.requires_grad)
-def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]:
+def profile_times(
+    module: nn.Sequential,
+    sample: TensorOrTensors,
+    timeout: float,
+    device: torch.device,
+) -> List[int]:
    """Profiles elapsed times per layer."""
    if any(p.grad is not None for p in module.parameters()):
        raise ValueError("some parameter already has gradient")

@@ -95,7 +103,11 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
def profile_sizes(
-    module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device,
+    module: nn.Sequential,
+    input: TensorOrTensors,
+    chunks: int,
+    param_scale: float,
+    device: torch.device,
) -> List[int]:
    """Profiles CUDA memory usage per layer."""
    if device.type != "cuda":
......
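`profile_times` and `profile_sizes`, reflowed above, measure per-layer latency and CUDA memory in order to balance a pipeline. A rough sketch of per-layer timing (the timeout handling and warm-up of the real profiler are omitted, and synchronization is simplified):

```
import time

import torch
import torch.nn as nn

def rough_profile_times(module: nn.Sequential, sample: torch.Tensor) -> list:
    times = []
    x = sample
    for layer in module:
        if x.is_cuda:
            torch.cuda.synchronize()  # make sure prior kernels finished
        begin = time.time()
        x = layer(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for this layer's kernels
        times.append(time.time() - begin)
    return times
```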
@@ -197,7 +197,10 @@ class Context:
    pass
-def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
+def save_rng_states(
+    device: torch.device,
+    rng_states: Deque[RNGStates],
+) -> None:
    """:meth:`Checkpoint.forward` captures the current PyTorch's random number
    generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.

@@ -217,7 +220,10 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
@contextmanager
-def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
+def restore_rng_states(
+    device: torch.device,
+    rng_states: Deque[RNGStates],
+) -> Generator[None, None, None]:
    """:meth:`Recompute.backward` restores the random number generator states
    captured by :func:`save_rng_states` within its context.

@@ -264,7 +270,10 @@ class Checkpoint(torch.autograd.Function):
        return output
    @staticmethod
-    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
+    def backward(
+        ctx: Context,
+        *grad_output: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        output, input_leaf = ctx.recomputed.pop()
        if isinstance(output, tuple):
......
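`save_rng_states` and `restore_rng_states` above capture the CPU and GPU RNG state so that dropout behaves identically when the checkpointed forward is recomputed. The corresponding stock PyTorch calls, as a sketch:

```
import torch

def capture_rng(device: torch.device):
    cpu_state = torch.get_rng_state()
    gpu_state = torch.cuda.get_rng_state(device) if device.type == "cuda" else None
    return cpu_state, gpu_state

def restore_rng(device: torch.device, states) -> None:
    cpu_state, gpu_state = states
    torch.set_rng_state(cpu_state)
    if gpu_state is not None:
        torch.cuda.set_rng_state(gpu_state, device)
```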
@@ -45,7 +45,12 @@ class Copy(torch.autograd.Function):
    @staticmethod
    # type: ignore
-    def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
+    def forward(
+        ctx: Context,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        *input: Tensor,
+    ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

@@ -66,7 +71,10 @@ class Copy(torch.autograd.Function):
        return tuple(output)
    @staticmethod
-    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
+    def backward(
+        ctx: Context,
+        *grad_output: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

@@ -98,7 +106,12 @@ class Wait(torch.autograd.Function):
    @staticmethod
    # type: ignore
-    def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
+    def forward(
+        ctx: Context,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        *input: Tensor,
+    ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

@@ -107,7 +120,10 @@ class Wait(torch.autograd.Function):
        return tuple(x.detach() for x in input)
    @staticmethod
-    def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
+    def backward(
+        ctx: Context,
+        *grad_input: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream
......
@@ -106,7 +106,9 @@ class BalanceError(ValueError):
def split_module(
-    module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
+    module: nn.Sequential,
+    balance: Iterable[int],
+    devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
    """Splits a module into multiple partitions.

@@ -350,12 +352,12 @@ class Pipe(Module):
        raise MOVING_DENIED
    def to(self, *args: Any, **kwargs: Any) -> "Pipe":
-        """ Deny these usages:
+        """Deny these usages:
        - to(device[, dtype, non_blocking])
        - to(tensor[, non_blocking])
        But allow this:
        - to(dtype[, non_blocking])"""
        if "device" in kwargs or "tensor" in kwargs:
            raise MOVING_DENIED
......
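`split_module`, reflowed above, partitions an `nn.Sequential` according to a `balance` list and places each partition on its device. A simplified sketch of that slicing (validation and the device bookkeeping of the real function are left out):

```
import torch
import torch.nn as nn

def naive_split(module: nn.Sequential, balance, devices):
    partitions, layers = [], list(module.children())
    offset = 0
    for count, device in zip(balance, devices):
        # Take the next `count` layers and place them on their device.
        part = nn.Sequential(*layers[offset:offset + count]).to(device)
        partitions.append(part)
        offset += count
    return partitions
```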
@@ -130,7 +130,10 @@ class Pipeline:
        self.compute(batches, schedule, skip_trackers)
    def fence(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
+        self,
+        batches: List[Batch],
+        schedule: List[Tuple[int, int]],
+        skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.

@@ -155,7 +158,10 @@ class Pipeline:
            copy(batches[i], prev_stream, next_stream)
    def compute(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
+        self,
+        batches: List[Batch],
+        schedule: List[Tuple[int, int]],
+        skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        partitions = self.partitions
......
@@ -39,7 +39,11 @@ class SkipLayout:
    # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
    by_src_partition: List[List[Tuple[int, Namespace, str]]]
-    def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
+    def __init__(
+        self,
+        num_partitions: int,
+        skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],
+    ) -> None:
        # The skip routes are already indexed by 'ns, name'.
        self.by_ns_name = skip_routes
......