"vscode:/vscode.git/clone" did not exist on "ed72e9212620d4de10fbe476f0b7af2ab94e4cd7"
Unverified commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Setup pre-commit github action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
@@ -16,7 +16,7 @@ DEBUG = False
 def _next_power_of_2_or_max(n: int, max_n: int) -> int:
-    """ Return the smallest power of 2 greater than or equal to n, with a limit.
+    """Return the smallest power of 2 greater than or equal to n, with a limit.
     Useful when used in splitting a tensor into chunks with power-of-2 sizes.
     """
@@ -50,7 +50,7 @@ def _reshape_inputs(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Te
 def get_data(
     shape: Tuple[Tuple[int, int], Tuple[int, int]], dtype: torch.dtype = torch.float16, device: str = "cuda"
 ) -> Tuple[torch.Tensor, nn.Parameter, torch.Tensor]:
-    """ Utility function for getting some tensors for testing and benchmarking."""
+    """Utility function for getting some tensors for testing and benchmarking."""
     (tokens, d1), (d2, vocabs) = shape
     assert d1 == d2
     input = torch.rand(tokens, d1, device=device, dtype=dtype).requires_grad_(True)
@@ -66,7 +66,7 @@ def get_data(
 class BaselineSoftmax(nn.Module):
-    """ Baseline softmax that does an output linear projection and a softmax.
+    """Baseline softmax that does an output linear projection and a softmax.
     This is intended to be used with an embedding layer with shared weights.
@@ -94,7 +94,7 @@ class BaselineSoftmax(nn.Module):
         self.log_softmax = log_softmax
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:  # type: ignore
-        """ Forward function that computes softmax output with the input and target."""
+        """Forward function that computes softmax output with the input and target."""
         assert isinstance(input, torch.Tensor)
         assert isinstance(target, torch.Tensor)
         input, target = _reshape_inputs(input, target)
@@ -111,7 +111,7 @@ class BaselineSoftmax(nn.Module):
 class BaselineSoftmaxNllLoss(BaselineSoftmax):
-    """ Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
+    """Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
     See BaselineSoftmax above. Constructor is the same. Only difference is in the
     forward function.
@@ -333,9 +333,7 @@ class BackwardTriggerFn(torch.autograd.Function):
     """A backward trigger function."""
     @staticmethod
-    def forward(  # type: ignore
-        ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
         """We take a weight tensor and the trigger as inputs and output the weight directly."""
         if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
             print("DEBUG trigger fwd")
@@ -388,7 +386,7 @@ class BackwardTrigger(nn.Module):
 class MemoryEfficientVocabOutput(nn.Module):  # AKA. MEVO
-    """ Fused fc + softmax + nll_loss in a tiled fashion.
+    """Fused fc + softmax + nll_loss in a tiled fashion.
     This uses much less memory but is quite a bit slower.
...
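For orientation (this note is not part of the diff): the baseline that MEVO is measured against projects the hidden states with the shared embedding weight and then applies log-softmax plus NLL loss over the full vocabulary, materializing the whole logits tensor. A minimal plain-PyTorch sketch of that computation, with made-up sizes:

import torch
import torch.nn.functional as F

tokens, d_model, vocab = 8, 16, 32  # illustrative sizes only
hidden = torch.rand(tokens, d_model, requires_grad=True)
embedding_weight = torch.nn.Parameter(torch.rand(vocab, d_model))  # shared with the input embedding
target = torch.randint(0, vocab, (tokens,))

logits = F.linear(hidden, embedding_weight)  # output projection: (tokens, vocab)
loss = F.nll_loss(F.log_softmax(logits, dim=-1), target, reduction="sum")
loss.backward()  # the full (tokens, vocab) logits live in memory for the backward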
@@ -80,7 +80,11 @@ class ModelShard(nn.Module):
     """
     def __init__(
-        self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, index: int,
+        self,
+        cpu_model_shard: nn.Module,
+        device: torch.device,
+        offload_device: torch.device,
+        index: int,
     ):
         super().__init__()
         self.model_shard = cpu_model_shard
@@ -457,17 +461,25 @@ class OffloadModel(nn.Module):
             # This is already sharded using the auto shard functionality.
             for i, m in enumerate(model):
                 self.model_slices.append(
-                    ModelShard(cpu_model_shard=m, device=device, offload_device=offload_device, index=i,)
+                    ModelShard(
+                        cpu_model_shard=m,
+                        device=device,
+                        offload_device=offload_device,
+                        index=i,
+                    )
                 )
         else:
             # Slice the model into roughly equivalent sequential shards.
-            splits = _split(model, num_slices)
+            splits = _split(model, num_slices)  # type: ignore
             for i, split in enumerate(splits):
                 # Add one model handling this slice
                 self.model_slices.append(
                     ModelShard(
-                        cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i,
+                        cpu_model_shard=nn.Sequential(*split),
+                        device=device,
+                        offload_device=offload_device,
+                        index=i,
                     )
                 )
...
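Context for the hunk above (an illustrative sketch, not fairscale's implementation): a ModelShard-style wrapper keeps its slice of the model on an offload device and only moves it to the compute device around the forward pass. Roughly:

import torch
import torch.nn as nn

class TinyShard(nn.Module):
    """Keep a sub-module on offload_device; move it to device only to run forward."""

    def __init__(self, shard: nn.Module, device: torch.device, offload_device: torch.device, index: int):
        super().__init__()
        self.model_shard = shard.to(offload_device)
        self.device = device
        self.offload_device = offload_device
        self.index = index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.model_shard.to(self.device)          # bring the weights in
        out = self.model_shard(x.to(self.device))
        self.model_shard.to(self.offload_device)  # push the weights back out
        return out

shard = TinyShard(nn.Linear(4, 4), torch.device("cpu"), torch.device("cpu"), index=0)
print(shard(torch.rand(2, 4)).shape)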
@@ -200,6 +200,8 @@ class DynamicLossScaler(object):
     def state_dict(self) -> Optional[Dict[str, float]]:
         if self.loss_scale is not None:
             return {"loss_scale": self.loss_scale}
+        else:
+            return None
     def load_state_dict(self, state_dict: Dict[str, float]) -> None:
         if "loss_scale" in state_dict:
...
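The added `else: return None` branch only makes the Optional return explicit. As a hedged sketch of how such a scaler's checkpoint round-trip is typically exercised (the class below is illustrative, not the fairscale API):

from typing import Dict, Optional

class TinyLossScaler:
    def __init__(self, loss_scale: Optional[float] = 2.0 ** 15):
        self.loss_scale = loss_scale

    def state_dict(self) -> Optional[Dict[str, float]]:
        if self.loss_scale is not None:
            return {"loss_scale": self.loss_scale}
        else:
            return None

    def load_state_dict(self, state_dict: Dict[str, float]) -> None:
        if "loss_scale" in state_dict:
            self.loss_scale = state_dict["loss_scale"]

saved = TinyLossScaler().state_dict()        # {'loss_scale': 32768.0}
restored = TinyLossScaler(loss_scale=None)   # state_dict() would return None here
if saved is not None:
    restored.load_state_dict(saved)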
@@ -35,7 +35,8 @@ class TraceForwardEvent(NamedTuple):
     @classmethod
     def from_dict(cls, serialized: Dict[str, Any]) -> "TraceForwardEvent":
         return TraceForwardEvent(
-            memory_diff=serialized["memory_diff"], memory_activations=serialized["memory_activations"],
+            memory_diff=serialized["memory_diff"],
+            memory_activations=serialized["memory_activations"],
         )
@@ -410,7 +411,8 @@ class LayerwiseMemoryTracker:
                 all_gathered=self._last_all_gather_memory,
                 cumul_all_gathered=sum(self._cumul_all_gather_memory),
                 event=TraceForwardEvent(
-                    memory_diff=allocated - self._memory_pre_forward, memory_activations=activations,
+                    memory_diff=allocated - self._memory_pre_forward,
+                    memory_activations=activations,
                 ),
             )
         )
@@ -593,7 +595,9 @@ def suggest_checkpoint_location(
     # Then map it back to module names
     return SuggestedCheckpoints(
-        max_memory=max_memory, split_modules=[modules[i] for i in reset_indices], all_modules=modules,
+        max_memory=max_memory,
+        split_modules=[modules[i] for i in reset_indices],
+        all_modules=modules,
     )
@@ -609,7 +613,9 @@ def _assert_visualisation_library_installed() -> None:
 def compare_memory_traces_in_plot(
-    memory_traces_by_job: Dict[str, List[LayerMemoryTrace]], figsize: Tuple[int, int] = (16, 20), capture: bool = False,
+    memory_traces_by_job: Dict[str, List[LayerMemoryTrace]],
+    figsize: Tuple[int, int] = (16, 20),
+    capture: bool = False,
 ) -> Optional[Any]:
     """
     Create a plot of the memory allocation over time during the forward/backward
@@ -684,7 +690,7 @@ class _MemoryGraphCreator:
         ax.plot(x, y_forward, x, y_backward, label=job_name)
         max_index = np.argmax(allocated_memory)
-        max_trace = memory_traces[max_index]  # type: ignore
+        max_trace = memory_traces[max_index]
         max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
         max_phase = "fwd" if max_trace.is_forward else "bwd"
         ax.set_ylim([None, max_trace.allocated * 1.1])
@@ -722,7 +728,7 @@ class _MemoryGraphCreator:
         # Adding the name of the layer with max cumulative all_gathered memory
         max_index = np.argmax(cumul_gathered_memory)
-        max_trace = memory_traces[max_index]  # type: ignore
+        max_trace = memory_traces[max_index]
         max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
         ax.set_ylim([None, max_trace.cumul_all_gathered * 1.1])
         x_text, y_text = max(0, max_index * 0.8), max_trace.cumul_all_gathered * 1.04  # type: ignore
...
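For readers unfamiliar with the tracker being reformatted above: it records, per module and per forward/backward event, how much memory the activations take. A rough, device-agnostic sketch of the idea using plain forward hooks (not the LayerwiseMemoryTracker API):

import torch
import torch.nn as nn

activation_bytes = {}

def track(name):
    def hook(module, inputs, output):
        # only record plain tensor outputs in this sketch
        if isinstance(output, torch.Tensor):
            activation_bytes[name] = output.numel() * output.element_size()
    return hook

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4))
handles = [m.register_forward_hook(track(n)) for n, m in model.named_modules() if n]
model(torch.rand(16, 8))
for h in handles:
    h.remove()
print(activation_bytes)  # bytes of activations produced by each sub-module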
@@ -95,7 +95,10 @@ def is_recomputing() -> bool:
     return thread_local.is_recomputing
-def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False,) -> nn.Module:
+def checkpoint_wrapper(
+    module: nn.Module,
+    offload_to_cpu: bool = False,
+) -> nn.Module:
     """
     A friendlier wrapper for performing activation checkpointing.
...
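The signature reflowed above is the public one, so a usage sketch is short. The import path below is the export I believe fairscale provides; treat it as an assumption and adjust to the installed version:

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper  # assumed export path

block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
block = checkpoint_wrapper(block, offload_to_cpu=False)  # recompute activations during backward

x = torch.rand(4, 32, requires_grad=True)
block(x).sum().backward()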
@@ -448,7 +448,7 @@ class FullyShardedDataParallel(nn.Module):
         return self._fsdp_wrapped_module
     def append_shared_param(self, p: Parameter) -> None:
-        """ Add a param that's already owned by another FSDP wrapper.
+        """Add a param that's already owned by another FSDP wrapper.
         .. warning:: This is experimental!
@@ -1248,7 +1248,7 @@ class FullyShardedDataParallel(nn.Module):
             m._reducer = self._reducer
     def _setup_output_hook_list(self) -> None:
-        """ set up a list to avoid registering pre-backward hooks
+        """set up a list to avoid registering pre-backward hooks
         incorrectly.
         """
         assert self._is_root, "This should only be called on the root"
...
@@ -31,7 +31,7 @@ def _trainable(param: torch.Tensor) -> bool:
 class ShardedDataParallel(nn.Module):
-    """ Wrap the model, and reduce the gradients to the right rank during the backward pass.
+    """Wrap the model, and reduce the gradients to the right rank during the backward pass.
     - the partition is given by the sharded optimizer
     - wrap the base model with a model which knows where to reduce each gradient
@@ -224,7 +224,10 @@ class ShardedDataParallel(nn.Module):
         return self.module(*inputs, **kwargs)
     def to(  # type: ignore
-        self, device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False,
+        self,
+        device: Optional[torch.device],
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> "ShardedDataParallel":
         """
         Moves and/or casts the parameters and buffers.
@@ -273,7 +276,7 @@ class ShardedDataParallel(nn.Module):
         self.refresh_trainable()
     def refresh_trainable(self) -> None:
-        """ If the module trainability has changed, update all the assumptions """
+        """If the module trainability has changed, update all the assumptions"""
         # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
         if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
@@ -628,7 +631,10 @@ class ShardedDataParallel(nn.Module):
             self._work_handles.append(
                 Workhandle(
                     handle=dist.reduce(
-                        tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
+                        tensor=bucket.buffer,
+                        dst=bucket.destination,
+                        group=self._process_group,
+                        async_op=True,
                    ),
                    callback=None,
                )
...
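The `dist.reduce(...)` reflow above sits in the gradient-bucketing path of ShardedDataParallel. For context, the usual pairing (as I recall it from the project quickstart; exact arguments may differ across versions) is OSS as the sharded optimizer plus ShardedDataParallel as the wrapper, and it needs an initialized process group. The single-process gloo group below exists only to make the sketch self-contained:

import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.oss import OSS

# single-process group just to make the sketch runnable
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = nn.Linear(8, 2)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
model = ShardedDDP(model, optimizer)

loss = model(torch.rand(4, 8)).sum()
loss.backward()   # gradients are reduced to the rank that owns each shard
optimizer.step()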
@@ -37,12 +37,12 @@ if TYPE_CHECKING:
 class FlatParameter(nn.Parameter):
-    """ A parameter that is initialized from a list of parameters and can be
+    """A parameter that is initialized from a list of parameters and can be
     turned into a list of views as needed.
     """
     def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) -> "FlatParameter":
-        """ Make an object using the parent's __new__ function. """
+        """Make an object using the parent's __new__ function."""
         # An empty or non-list input doesn't make sense.
         if not isinstance(params, (list, tuple)) or len(params) == 0:
@@ -66,7 +66,7 @@ class FlatParameter(nn.Parameter):
         return super(FlatParameter, cls).__new__(cls, data, requires_grad=requires_grad)
     def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
-        """ Initialize the _param_numels and _param_shapes lists. """
+        """Initialize the _param_numels and _param_shapes lists."""
         self._param_numels = [p.numel() for p in params]
         assert self.numel() <= sum(
             self._param_numels
@@ -78,7 +78,7 @@ class FlatParameter(nn.Parameter):
         self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []
     def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
-        """ Return a generator of views that map to the original parameters. """
+        """Return a generator of views that map to the original parameters."""
         # Note, self.data could be sharded, so its numel is <= to the sum.
         assert self.data.numel() <= sum(
             self._param_numels
@@ -96,14 +96,14 @@ class FlatParameter(nn.Parameter):
         return names, self._param_shapes, self._param_numels
     def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
-        """ Used by pickle to set the internal states. """
+        """Used by pickle to set the internal states."""
         (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
         assert self.numel() <= sum(
             self._param_numels
         ), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
     def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
-        """ Support pickling between ranks. """
+        """Support pickling between ranks."""
         return (
             FlatParameter,  # Callable
             # Args to the callable above
@@ -228,14 +228,14 @@ class FlattenParamsWrapper(nn.Module):
     @property
     def module(self) -> Any:
-        """ Support fpw.module in case we are imitating DDP, which has .module
+        """Support fpw.module in case we are imitating DDP, which has .module
         property to the underlying module.
         """
         return self._fpw_module
     @property
     def flat_param(self) -> nn.Parameter:
-        """ We used to support only a single flat_param. This allows us to
+        """We used to support only a single flat_param. This allows us to
         be backward compatible.
         """
         assert len(self.flat_params) == 1, "Incorrect access to flat_param"
@@ -246,7 +246,7 @@ class FlattenParamsWrapper(nn.Module):
     ) -> Tuple[
         List[nn.Parameter], List[Tuple[str, nn.Module, str]], List[Tuple[str, str, nn.Module, str, nn.Module, str]]
     ]:
-        """ Build metadata for need-to-be-flatten parameters and returns a list
+        """Build metadata for need-to-be-flatten parameters and returns a list
         contains the need-to-be-flatten parameters.
         This also returns param_infos and shared_param_infos, which
@@ -287,7 +287,7 @@ class FlattenParamsWrapper(nn.Module):
         return chain(*[p._shared_param_infos for p in self.flat_params])
     def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
-        """ Flatten the managed parameters and replaced the original
+        """Flatten the managed parameters and replaced the original
         attributes with views to the flat params.
         """
         assert not self.is_flattened
@@ -309,7 +309,7 @@ class FlattenParamsWrapper(nn.Module):
         self._unflatten_params_as_views()
     def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
-        """ Undo flattening and create separate parameters from the already flattened
+        """Undo flattening and create separate parameters from the already flattened
         self.flat_param or a user supplied external data.
         """
         assert self.is_flattened or external_data is not None
@@ -336,7 +336,7 @@ class FlattenParamsWrapper(nn.Module):
         self.flat_params = []
     def _unflatten_params_as_views(self) -> None:
-        """ Unlike ``_unflatten_params``, this function unflatten into views and keep
+        """Unlike ``_unflatten_params``, this function unflatten into views and keep
         self.flat_param unchanged.
         """
         assert self.is_flattened
@@ -459,7 +459,7 @@ class FlattenParamsWrapper(nn.Module):
         return self.module(*inputs, **kwinputs)
     def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
-        """ Used to get a generator over all views from a list of external data list. """
+        """Used to get a generator over all views from a list of external data list."""
         params = self.flat_params
         if external_data_list is None:
             external_data_list = [None] * len(params)
...
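The FlatParameter docstring fixes above do not show the core trick, so here is a minimal plain-PyTorch sketch of it: concatenate several parameters into one flat buffer, then hand back reshaped views of that buffer on demand.

import torch

params = [torch.rand(3, 4), torch.rand(5), torch.rand(2, 2)]
numels = [p.numel() for p in params]   # like _param_numels
shapes = [p.shape for p in params]     # like _param_shapes

flat = torch.cat([p.reshape(-1) for p in params])  # one contiguous buffer
views = [chunk.view(shape) for chunk, shape in zip(flat.split(numels), shapes)]

assert flat.numel() == sum(numels)
assert all(v.shape == s for v, s in zip(views, shapes))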
@@ -120,19 +120,17 @@ class GradBucket(Bucket):
         self.callback: Optional[Callable[[Any], None]] = None
     def reset_checked_in(self) -> None:
-        """ Reset the counter of the parameter grads which have been checked in
-        """
+        """Reset the counter of the parameter grads which have been checked in"""
         self.params_checked_in = 0
         self.sent = False
     @property
     def all_checked_in(self) -> bool:
-        """ Have all the expected gradient check-in happened ?"""
+        """Have all the expected gradient check-in happened ?"""
         return len(self._params) == self.params_checked_in
     def can_add_grad_view(self, param: torch.Tensor) -> bool:
-        """ Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
-        """
+        """Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?"""
         return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
     def to(  # type: ignore
...
@@ -117,7 +117,11 @@ class Top2Gate(torch.nn.Module):
     wg: torch.nn.Linear
-    def __init__(self, model_dim: int, num_experts: int,) -> None:
+    def __init__(
+        self,
+        model_dim: int,
+        num_experts: int,
+    ) -> None:
         super().__init__()
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
...
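The reflowed constructor above is just a bias-free Linear from model_dim to num_experts. A hedged sketch of what top-2 gating does with those logits (not the Top2Gate algorithm verbatim, which also adds capacity and load-balancing terms):

import torch
import torch.nn.functional as F

model_dim, num_experts, tokens = 16, 4, 10
wg = torch.nn.Linear(model_dim, num_experts, bias=False)

x = torch.rand(tokens, model_dim)
gates = F.softmax(wg(x), dim=-1)                            # (tokens, num_experts)
top2_vals, top2_idx = torch.topk(gates, k=2, dim=-1)        # pick two experts per token
weights = top2_vals / top2_vals.sum(dim=-1, keepdim=True)   # renormalize over the chosen two
print(top2_idx[0], weights[0])                              # experts and mixing weights for token 0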
@@ -222,7 +222,10 @@ class AsyncPipe(Module):
         )
     def instantiate_partition(
-        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
+        self,
+        module: Union[nn.Sequential, List[LazyModule]],
+        balance: List[int],
+        group: torch.distributed.ProcessGroup,
     ) -> List[ModuleWrapper]:
         layers: NamedModules = OrderedDict()
...
@@ -64,7 +64,13 @@ class AsyncPipeline:
         skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
         rank = self.group.rank()
-        event_loop = AsyncEventLoop(self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,)
+        event_loop = AsyncEventLoop(
+            self.partitions,
+            self.group,
+            self.transport,
+            self.training,
+            self.checkpoint_stop,
+        )
         if rank == 0 and not self.final_stage:
             logging.debug(f"{torch.distributed.get_rank()}: entered event head")
             event_loop.event_loop_head(batches, skip_trackers, event)
...
@@ -169,7 +169,10 @@ class AsyncRecvOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
+    def backward(
+        ctx,
+        *grad: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:
         ranks = get_pipeline_parallel_ranks()
         this_rank = torch.distributed.get_rank()
         body = AsyncMessageBody(
@@ -177,7 +180,11 @@ class AsyncRecvOperator(torch.autograd.Function):
         )
         ctx.transport.send_message(
             PipeMessage(
-                this_rank, ranks[ctx.args.source.stage], queue_name=ctx.queue_name, args=body, tensors=tuple(grad),
+                this_rank,
+                ranks[ctx.args.source.stage],
+                queue_name=ctx.queue_name,
+                args=body,
+                tensors=tuple(grad),
             ),
             sync=True,
         )
@@ -242,7 +249,12 @@ class AsyncEventLoop:
         to the next stage in the pipeline if needed."""
         task = create_task(
-            self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
+            self.checkpoint_stop,
+            batch.index,
+            self.group.rank(),
+            batch,
+            partition.module,
+            skip_trackers,
         )
         result = task.compute()
         task.finalize(result)
@@ -267,7 +279,7 @@ class AsyncEventLoop:
         # All batches saved in `activations` are generated by AutogradWithoutActivations,
         # so we store the gradients in `grad_from_pipeline` so it will be used
         # during the backward pass
-        batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors)  # type: ignore
+        batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors)
         batch.tensor.backward(retain_graph=True)
     def run_invocations_on_batch(
...
@@ -37,7 +37,10 @@ Tensors = Tuple[Tensor, ...]
 TensorOrTensors = Union[Tensor, Tensors]
-def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
+def layerwise_sandbox(
+    module: nn.Sequential,
+    device: torch.device,
+) -> Generator[nn.Module, None, None]:
     """Copies layers for ease to profile. It doesn't modify the given
     module.
     """
@@ -54,7 +57,12 @@ def detach(batch: Batch) -> None:
         batch[i] = x.detach().requires_grad_(x.requires_grad)
-def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]:
+def profile_times(
+    module: nn.Sequential,
+    sample: TensorOrTensors,
+    timeout: float,
+    device: torch.device,
+) -> List[int]:
     """Profiles elapsed times per layer."""
     if any(p.grad is not None for p in module.parameters()):
         raise ValueError("some parameter already has gradient")
@@ -95,7 +103,11 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
 def profile_sizes(
-    module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device,
+    module: nn.Sequential,
+    input: TensorOrTensors,
+    chunks: int,
+    param_scale: float,
+    device: torch.device,
 ) -> List[int]:
     """Profiles CUDA memory usage per layer."""
     if device.type != "cuda":
...
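profile_times and profile_sizes only had their signatures reflowed; conceptually they run the sample through the layers one by one and record cost per layer for the balancing heuristics. A rough sketch of the timing half in plain PyTorch (CPU wall-clock timing; the real code also synchronizes CUDA and honors a timeout):

import time
import torch
import torch.nn as nn

def rough_profile_times(module: nn.Sequential, sample: torch.Tensor) -> list:
    times = []
    x = sample
    for layer in module:
        begin = time.time()
        x = layer(x)                       # forward through this layer only
        times.append(time.time() - begin)
    return times

model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
print(rough_profile_times(model, torch.rand(32, 64)))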
@@ -197,7 +197,10 @@ class Context:
     pass
-def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
+def save_rng_states(
+    device: torch.device,
+    rng_states: Deque[RNGStates],
+) -> None:
     """:meth:`Checkpoint.forward` captures the current PyTorch's random number
     generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
@@ -217,7 +220,10 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
 @contextmanager
-def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
+def restore_rng_states(
+    device: torch.device,
+    rng_states: Deque[RNGStates],
+) -> Generator[None, None, None]:
     """:meth:`Recompute.backward` restores the random number generator states
     captured by :func:`save_rng_states` within its context.
@@ -264,7 +270,10 @@ class Checkpoint(torch.autograd.Function):
         return output
     @staticmethod
-    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
+    def backward(
+        ctx: Context,
+        *grad_output: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
         output, input_leaf = ctx.recomputed.pop()
         if isinstance(output, tuple):
...
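save_rng_states/restore_rng_states exist so that recomputation during the checkpoint backward sees the same randomness (dropout and friends) as the original forward. The building blocks are just PyTorch's RNG-state getters and setters; a minimal CPU-only sketch:

import torch

cpu_rng_state = torch.get_rng_state()   # capture, as Checkpoint.forward does
a = torch.rand(3)                       # "original" forward randomness

torch.set_rng_state(cpu_rng_state)      # restore, as Recompute.backward does
b = torch.rand(3)                       # recomputation sees the same numbers

assert torch.equal(a, b)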
@@ -45,7 +45,12 @@ class Copy(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
+    def forward(
+        ctx: Context,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        *input: Tensor,
+    ) -> Tensors:
         ctx.prev_stream = prev_stream
         ctx.next_stream = next_stream
@@ -66,7 +71,10 @@ class Copy(torch.autograd.Function):
         return tuple(output)
     @staticmethod
-    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
+    def backward(
+        ctx: Context,
+        *grad_output: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:
         prev_stream = ctx.prev_stream
         next_stream = ctx.next_stream
@@ -98,7 +106,12 @@ class Wait(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
+    def forward(
+        ctx: Context,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        *input: Tensor,
+    ) -> Tensors:
         ctx.prev_stream = prev_stream
         ctx.next_stream = next_stream
@@ -107,7 +120,10 @@ class Wait(torch.autograd.Function):
         return tuple(x.detach() for x in input)
     @staticmethod
-    def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
+    def backward(
+        ctx: Context,
+        *grad_input: Tensor,
+    ) -> Tuple[Optional[Tensor], ...]:
         prev_stream = ctx.prev_stream
         next_stream = ctx.next_stream
...
@@ -106,7 +106,9 @@ class BalanceError(ValueError):
 def split_module(
-    module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
+    module: nn.Sequential,
+    balance: Iterable[int],
+    devices: List[torch.device],
 ) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
     """Splits a module into multiple partitions.
@@ -350,7 +352,7 @@ class Pipe(Module):
         raise MOVING_DENIED
     def to(self, *args: Any, **kwargs: Any) -> "Pipe":
-        """ Deny these usages:
+        """Deny these usages:
         - to(device[, dtype, non_blocking])
         - to(tensor[, non_blocking])
...
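split_module, whose signature is reflowed above, partitions an nn.Sequential according to a balance list, one partition per device. A hedged sketch of just the balance-splitting idea, independent of the Pipe internals:

import torch.nn as nn
from typing import Iterable, List

def rough_split(module: nn.Sequential, balance: Iterable[int]) -> List[nn.Sequential]:
    layers = list(module)
    partitions, start = [], 0
    for size in balance:
        partitions.append(nn.Sequential(*layers[start : start + size]))
        start += size
    assert start == len(layers), "balance must cover every layer"
    return partitions

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
parts = rough_split(model, balance=[2, 2])
print([len(p) for p in parts])  # [2, 2]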
@@ -130,7 +130,10 @@ class Pipeline:
         self.compute(batches, schedule, skip_trackers)
     def fence(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
+        self,
+        batches: List[Batch],
+        schedule: List[Tuple[int, int]],
+        skip_trackers: List[SkipTrackerThroughPotals],
     ) -> None:
         """Copies micro-batches after computation for the previous
         micro-batches.
@@ -155,7 +158,10 @@ class Pipeline:
             copy(batches[i], prev_stream, next_stream)
     def compute(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
+        self,
+        batches: List[Batch],
+        schedule: List[Tuple[int, int]],
+        skip_trackers: List[SkipTrackerThroughPotals],
     ) -> None:
         """Runs tasks with synchronization to copy streams."""
         partitions = self.partitions
...
@@ -39,7 +39,11 @@ class SkipLayout:
     # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
     by_src_partition: List[List[Tuple[int, Namespace, str]]]
-    def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
+    def __init__(
+        self,
+        num_partitions: int,
+        skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],
+    ) -> None:
         # The skip routes are already indexed by 'ns, name'.
         self.by_ns_name = skip_routes
...