Unverified Commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Setup pre-commit github action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
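
For context, a GitHub Actions workflow that runs pre-commit over the whole repository generally looks like the sketch below. This is an illustration only: the workflow file name, step layout, and action versions are assumptions rather than contents of this commit; only the Python version (3.9.7) and the pre-commit pin in requirements-dev.txt come from the commit message above.

# Hypothetical .github/workflows/pre-commit.yml -- an illustrative sketch,
# not the exact workflow added by this commit.
name: pre-commit

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: "3.9.7"  # version referenced in the commit message
      - name: Install pre-commit
        run: pip install pre-commit  # the actual change pins its version in requirements-dev.txt
      - name: Run all hooks on all files
        run: pre-commit run --all-files

The hooks themselves are declared in .pre-commit-config.yaml; running them over every file is what produces the purely mechanical whitespace and line-wrapping changes in the diff below.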
@@ -16,7 +16,7 @@ DEBUG = False
def _next_power_of_2_or_max(n: int, max_n: int) -> int:
""" Return the smallest power of 2 greater than or equal to n, with a limit.
"""Return the smallest power of 2 greater than or equal to n, with a limit.
Useful when used in splitting a tensor into chunks with power-of-2 sizes.
"""
@@ -50,7 +50,7 @@ def _reshape_inputs(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Te
def get_data(
shape: Tuple[Tuple[int, int], Tuple[int, int]], dtype: torch.dtype = torch.float16, device: str = "cuda"
) -> Tuple[torch.Tensor, nn.Parameter, torch.Tensor]:
""" Utility function for getting some tensors for testing and benchmarking."""
"""Utility function for getting some tensors for testing and benchmarking."""
(tokens, d1), (d2, vocabs) = shape
assert d1 == d2
input = torch.rand(tokens, d1, device=device, dtype=dtype).requires_grad_(True)
@@ -66,7 +66,7 @@ def get_data(
class BaselineSoftmax(nn.Module):
""" Baseline softmax that does an output linear projection and a softmax.
"""Baseline softmax that does an output linear projection and a softmax.
This is intended to be used with an embedding layer with shared weights.
@@ -94,7 +94,7 @@ class BaselineSoftmax(nn.Module):
self.log_softmax = log_softmax
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
""" Forward function that computes softmax output with the input and target."""
"""Forward function that computes softmax output with the input and target."""
assert isinstance(input, torch.Tensor)
assert isinstance(target, torch.Tensor)
input, target = _reshape_inputs(input, target)
@@ -111,7 +111,7 @@ class BaselineSoftmax(nn.Module):
class BaselineSoftmaxNllLoss(BaselineSoftmax):
""" Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
"""Baseline that does an output projection, a softmax & a NLL loss (cross-entropy).
See BaselineSoftmax above. Constructor is the same. Only difference is in the
forward function.
@@ -333,9 +333,7 @@ class BackwardTriggerFn(torch.autograd.Function):
"""A backward trigger function."""
@staticmethod
def forward( # type: ignore
ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor
) -> torch.Tensor:
def forward(ctx: Any, w: torch.Tensor, trigger_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
"""We take a weight tensor and the trigger as inputs and output the weight directly."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG trigger fwd")
@@ -388,7 +386,7 @@ class BackwardTrigger(nn.Module):
class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
""" Fused fc + softmax + nll_loss in a tiled fashion.
"""Fused fc + softmax + nll_loss in a tiled fashion.
This uses much less memory but is quite a bit slower.
@@ -80,7 +80,11 @@ class ModelShard(nn.Module):
"""
def __init__(
self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, index: int,
self,
cpu_model_shard: nn.Module,
device: torch.device,
offload_device: torch.device,
index: int,
):
super().__init__()
self.model_shard = cpu_model_shard
@@ -457,17 +461,25 @@ class OffloadModel(nn.Module):
# This is already sharded using the auto shard functionality.
for i, m in enumerate(model):
self.model_slices.append(
ModelShard(cpu_model_shard=m, device=device, offload_device=offload_device, index=i,)
ModelShard(
cpu_model_shard=m,
device=device,
offload_device=offload_device,
index=i,
)
)
else:
# Slice the model into roughly equivalent sequential shards.
splits = _split(model, num_slices)
splits = _split(model, num_slices) # type: ignore
for i, split in enumerate(splits):
# Add one model handling this slice
self.model_slices.append(
ModelShard(
cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i,
cpu_model_shard=nn.Sequential(*split),
device=device,
offload_device=offload_device,
index=i,
)
)
@@ -200,6 +200,8 @@ class DynamicLossScaler(object):
def state_dict(self) -> Optional[Dict[str, float]]:
if self.loss_scale is not None:
return {"loss_scale": self.loss_scale}
else:
return None
def load_state_dict(self, state_dict: Dict[str, float]) -> None:
if "loss_scale" in state_dict:
@@ -35,7 +35,8 @@ class TraceForwardEvent(NamedTuple):
@classmethod
def from_dict(cls, serialized: Dict[str, Any]) -> "TraceForwardEvent":
return TraceForwardEvent(
memory_diff=serialized["memory_diff"], memory_activations=serialized["memory_activations"],
memory_diff=serialized["memory_diff"],
memory_activations=serialized["memory_activations"],
)
@@ -410,7 +411,8 @@ class LayerwiseMemoryTracker:
all_gathered=self._last_all_gather_memory,
cumul_all_gathered=sum(self._cumul_all_gather_memory),
event=TraceForwardEvent(
memory_diff=allocated - self._memory_pre_forward, memory_activations=activations,
memory_diff=allocated - self._memory_pre_forward,
memory_activations=activations,
),
)
)
@@ -593,7 +595,9 @@ def suggest_checkpoint_location(
# Then map it back to module names
return SuggestedCheckpoints(
max_memory=max_memory, split_modules=[modules[i] for i in reset_indices], all_modules=modules,
max_memory=max_memory,
split_modules=[modules[i] for i in reset_indices],
all_modules=modules,
)
@@ -609,7 +613,9 @@ def _assert_visualisation_library_installed() -> None:
def compare_memory_traces_in_plot(
memory_traces_by_job: Dict[str, List[LayerMemoryTrace]], figsize: Tuple[int, int] = (16, 20), capture: bool = False,
memory_traces_by_job: Dict[str, List[LayerMemoryTrace]],
figsize: Tuple[int, int] = (16, 20),
capture: bool = False,
) -> Optional[Any]:
"""
Create a plot of the memory allocation over time during the forward/backward
@@ -684,7 +690,7 @@ class _MemoryGraphCreator:
ax.plot(x, y_forward, x, y_backward, label=job_name)
max_index = np.argmax(allocated_memory)
max_trace = memory_traces[max_index] # type: ignore
max_trace = memory_traces[max_index]
max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
max_phase = "fwd" if max_trace.is_forward else "bwd"
ax.set_ylim([None, max_trace.allocated * 1.1])
@@ -722,7 +728,7 @@ class _MemoryGraphCreator:
# Adding the name of the layer with max cumulative all_gathered memory
max_index = np.argmax(cumul_gathered_memory)
max_trace = memory_traces[max_index] # type: ignore
max_trace = memory_traces[max_index]
max_module = ".".join([n for n in max_trace.module_name.split(".") if not n.startswith("_")])
ax.set_ylim([None, max_trace.cumul_all_gathered * 1.1])
x_text, y_text = max(0, max_index * 0.8), max_trace.cumul_all_gathered * 1.04 # type: ignore
@@ -95,7 +95,10 @@ def is_recomputing() -> bool:
return thread_local.is_recomputing
def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False,) -> nn.Module:
def checkpoint_wrapper(
module: nn.Module,
offload_to_cpu: bool = False,
) -> nn.Module:
"""
A friendlier wrapper for performing activation checkpointing.
@@ -448,7 +448,7 @@ class FullyShardedDataParallel(nn.Module):
return self._fsdp_wrapped_module
def append_shared_param(self, p: Parameter) -> None:
""" Add a param that's already owned by another FSDP wrapper.
"""Add a param that's already owned by another FSDP wrapper.
.. warning:: This is experimental!
@@ -1248,7 +1248,7 @@ class FullyShardedDataParallel(nn.Module):
m._reducer = self._reducer
def _setup_output_hook_list(self) -> None:
""" set up a list to avoid registering pre-backward hooks
"""set up a list to avoid registering pre-backward hooks
incorrectly.
"""
assert self._is_root, "This should only be called on the root"
@@ -31,7 +31,7 @@ def _trainable(param: torch.Tensor) -> bool:
class ShardedDataParallel(nn.Module):
""" Wrap the model, and reduce the gradients to the right rank during the backward pass.
"""Wrap the model, and reduce the gradients to the right rank during the backward pass.
- the partition is given by the sharded optimizer
- wrap the base model with a model which knows where to reduce each gradient
@@ -224,7 +224,10 @@ class ShardedDataParallel(nn.Module):
return self.module(*inputs, **kwargs)
def to( # type: ignore
self, device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False,
self,
device: Optional[torch.device],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> "ShardedDataParallel":
"""
Moves and/or casts the parameters and buffers.
@@ -273,7 +276,7 @@ class ShardedDataParallel(nn.Module):
self.refresh_trainable()
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
"""If the module trainability has changed, update all the assumptions"""
# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
@@ -628,7 +631,10 @@ class ShardedDataParallel(nn.Module):
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
tensor=bucket.buffer,
dst=bucket.destination,
group=self._process_group,
async_op=True,
),
callback=None,
)
@@ -37,12 +37,12 @@ if TYPE_CHECKING:
class FlatParameter(nn.Parameter):
""" A parameter that is initialized from a list of parameters and can be
"""A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
"""
def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) -> "FlatParameter":
""" Make an object using the parent's __new__ function. """
"""Make an object using the parent's __new__ function."""
# An empty or non-list input doesn't make sense.
if not isinstance(params, (list, tuple)) or len(params) == 0:
@@ -66,7 +66,7 @@ class FlatParameter(nn.Parameter):
return super(FlatParameter, cls).__new__(cls, data, requires_grad=requires_grad)
def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
""" Initialize the _param_numels and _param_shapes lists. """
"""Initialize the _param_numels and _param_shapes lists."""
self._param_numels = [p.numel() for p in params]
assert self.numel() <= sum(
self._param_numels
@@ -78,7 +78,7 @@ class FlatParameter(nn.Parameter):
self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []
def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
""" Return a generator of views that map to the original parameters. """
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
assert self.data.numel() <= sum(
self._param_numels
@@ -96,14 +96,14 @@ class FlatParameter(nn.Parameter):
return names, self._param_shapes, self._param_numels
def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
""" Use by pickle to set the internal states. """
"""Use by pickle to set the internal states."""
(self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
assert self.numel() <= sum(
self._param_numels
), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
""" Support pickling between ranks. """
"""Support pickling between ranks."""
return (
FlatParameter, # Callable
# Args to the callable above
@@ -228,14 +228,14 @@ class FlattenParamsWrapper(nn.Module):
@property
def module(self) -> Any:
""" Support fpw.module in case we are immitating DDP, which has .module
"""Support fpw.module in case we are immitating DDP, which has .module
property to the underlying module.
"""
return self._fpw_module
@property
def flat_param(self) -> nn.Parameter:
""" We used to support only a single flat_param. This allows us to
"""We used to support only a single flat_param. This allows us to
be backward compatible.
"""
assert len(self.flat_params) == 1, "Incorrect access to flat_param"
@@ -246,7 +246,7 @@ class FlattenParamsWrapper(nn.Module):
) -> Tuple[
List[nn.Parameter], List[Tuple[str, nn.Module, str]], List[Tuple[str, str, nn.Module, str, nn.Module, str]]
]:
""" Build metadata for need-to-be-flatten parameters and returns a list
"""Build metadata for need-to-be-flatten parameters and returns a list
containing the need-to-be-flatten parameters.
This also returns param_infos and shared_param_infos, which
@@ -287,7 +287,7 @@ class FlattenParamsWrapper(nn.Module):
return chain(*[p._shared_param_infos for p in self.flat_params])
def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
""" Flatten the managed parameters and replaced the original
"""Flatten the managed parameters and replaced the original
attributes with views to the flat params.
"""
assert not self.is_flattened
@@ -309,7 +309,7 @@ class FlattenParamsWrapper(nn.Module):
self._unflatten_params_as_views()
def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
""" Undo flattening and create separate parameters from the already flattened
"""Undo flattening and create separate parameters from the already flattened
self.flat_param or a user supplied external data.
"""
assert self.is_flattened or external_data is not None
@@ -336,7 +336,7 @@ class FlattenParamsWrapper(nn.Module):
self.flat_params = []
def _unflatten_params_as_views(self) -> None:
""" Unlike ``_unflatten_params``, this function unflatten into views and keep
"""Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""
assert self.is_flattened
@@ -459,7 +459,7 @@ class FlattenParamsWrapper(nn.Module):
return self.module(*inputs, **kwinputs)
def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
""" Used to get a generator over all views from a list of external data list. """
"""Used to get a generator over all views from a list of external data list."""
params = self.flat_params
if external_data_list is None:
external_data_list = [None] * len(params)
@@ -120,19 +120,17 @@ class GradBucket(Bucket):
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
""" Reset the counter of the parameter grads which have been checked in
"""
"""Reset the counter of the parameter grads which have been checked in"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
""" Have all the expected gradient check-in happened ?"""
"""Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
"""Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?"""
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
@@ -117,7 +117,11 @@ class Top2Gate(torch.nn.Module):
wg: torch.nn.Linear
def __init__(self, model_dim: int, num_experts: int,) -> None:
def __init__(
self,
model_dim: int,
num_experts: int,
) -> None:
super().__init__()
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
@@ -222,7 +222,10 @@ class AsyncPipe(Module):
)
def instantiate_partition(
self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: List[int],
group: torch.distributed.ProcessGroup,
) -> List[ModuleWrapper]:
layers: NamedModules = OrderedDict()
@@ -64,7 +64,13 @@ class AsyncPipeline:
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
rank = self.group.rank()
event_loop = AsyncEventLoop(self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,)
event_loop = AsyncEventLoop(
self.partitions,
self.group,
self.transport,
self.training,
self.checkpoint_stop,
)
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
@@ -169,7 +169,10 @@ class AsyncRecvOperator(torch.autograd.Function):
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
def backward(
ctx,
*grad: Tensor,
) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
body = AsyncMessageBody(
@@ -177,7 +180,11 @@
)
ctx.transport.send_message(
PipeMessage(
this_rank, ranks[ctx.args.source.stage], queue_name=ctx.queue_name, args=body, tensors=tuple(grad),
this_rank,
ranks[ctx.args.source.stage],
queue_name=ctx.queue_name,
args=body,
tensors=tuple(grad),
),
sync=True,
)
@@ -242,7 +249,12 @@ class AsyncEventLoop:
to the next stage in the pipeline if needed."""
task = create_task(
self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
self.checkpoint_stop,
batch.index,
self.group.rank(),
batch,
partition.module,
skip_trackers,
)
result = task.compute()
task.finalize(result)
@@ -267,7 +279,7 @@ class AsyncEventLoop:
# All batches saved in `activations` are generated by AutogradWithoutActivations,
# so we store the gradients in `grad_from_pipeline` so it will be used
# during the backward pass
batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore
batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors)
batch.tensor.backward(retain_graph=True)
def run_invocations_on_batch(
@@ -37,7 +37,10 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
def layerwise_sandbox(
module: nn.Sequential,
device: torch.device,
) -> Generator[nn.Module, None, None]:
"""Copies layers for ease to profile. It doesn't modify the given
module.
"""
@@ -54,7 +57,12 @@ def detach(batch: Batch) -> None:
batch[i] = x.detach().requires_grad_(x.requires_grad)
def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]:
def profile_times(
module: nn.Sequential,
sample: TensorOrTensors,
timeout: float,
device: torch.device,
) -> List[int]:
"""Profiles elapsed times per layer."""
if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient")
@@ -95,7 +103,11 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
def profile_sizes(
module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device,
module: nn.Sequential,
input: TensorOrTensors,
chunks: int,
param_scale: float,
device: torch.device,
) -> List[int]:
"""Profiles CUDA memory usage per layer."""
if device.type != "cuda":
@@ -197,7 +197,10 @@ class Context:
pass
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
def save_rng_states(
device: torch.device,
rng_states: Deque[RNGStates],
) -> None:
""":meth:`Checkpoint.forward` captures the current PyTorch's random number
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
@@ -217,7 +220,10 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
def restore_rng_states(
device: torch.device,
rng_states: Deque[RNGStates],
) -> Generator[None, None, None]:
""":meth:`Recompute.backward` restores the random number generator states
captured by :func:`save_rng_states` within its context.
@@ -264,7 +270,10 @@ class Checkpoint(torch.autograd.Function):
return output
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
def backward(
ctx: Context,
*grad_output: Tensor,
) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
output, input_leaf = ctx.recomputed.pop()
if isinstance(output, tuple):
@@ -45,7 +45,12 @@ class Copy(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
def forward(
ctx: Context,
prev_stream: AbstractStream,
next_stream: AbstractStream,
*input: Tensor,
) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
@@ -66,7 +71,10 @@ class Copy(torch.autograd.Function):
return tuple(output)
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
def backward(
ctx: Context,
*grad_output: Tensor,
) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
@@ -98,7 +106,12 @@ class Wait(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
def forward(
ctx: Context,
prev_stream: AbstractStream,
next_stream: AbstractStream,
*input: Tensor,
) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
@@ -107,7 +120,10 @@ class Wait(torch.autograd.Function):
return tuple(x.detach() for x in input)
@staticmethod
def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
def backward(
ctx: Context,
*grad_input: Tensor,
) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
@@ -106,7 +106,9 @@ class BalanceError(ValueError):
def split_module(
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
module: nn.Sequential,
balance: Iterable[int],
devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
"""Splits a module into multiple partitions.
@@ -350,7 +352,7 @@ class Pipe(Module):
raise MOVING_DENIED
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
""" Deny these usages:
"""Deny these usages:
- to(device[, dtype, non_blocking])
- to(tensor[, non_blocking])
@@ -130,7 +130,10 @@ class Pipeline:
self.compute(batches, schedule, skip_trackers)
def fence(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
self,
batches: List[Batch],
schedule: List[Tuple[int, int]],
skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Copies micro-batches after computation for the previous
micro-batches.
@@ -155,7 +158,10 @@ class Pipeline:
copy(batches[i], prev_stream, next_stream)
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
self,
batches: List[Batch],
schedule: List[Tuple[int, int]],
skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Runs tasks with synchronization to copy streams."""
partitions = self.partitions
@@ -39,7 +39,11 @@ class SkipLayout:
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_src_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
def __init__(
self,
num_partitions: int,
skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],
) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes