Unverified commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Setup pre-commit github action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
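
The new `.pre-commit-config.yaml` and the GitHub Actions workflow that runs it are not expanded in the hunks below. For orientation only, a minimal pre-commit configuration kept in sync with the linter pins in requirements-dev.txt (black 21.10b0, flake8 4.0.1, isort 5.10.1, mypy 0.910) would look roughly like the sketch below; it uses the standard upstream hook repositories, and the exact contents of the committed files may differ:

    # Hypothetical .pre-commit-config.yaml sketch; versions taken from the
    # requirements-dev.txt pins in this commit, hook repos are the upstream defaults.
    repos:
      - repo: https://github.com/psf/black
        rev: 21.10b0
        hooks:
          - id: black
      - repo: https://github.com/PyCQA/flake8
        rev: 4.0.1
        hooks:
          - id: flake8
      - repo: https://github.com/PyCQA/isort
        rev: 5.10.1
        hooks:
          - id: isort
      - repo: https://github.com/pre-commit/mirrors-mypy
        rev: v0.910
        hooks:
          - id: mypy

Running `pre-commit run --all-files` once with such a configuration installed produces the kind of mechanical reformatting seen in the Python hunks below: calls and signatures that end in a trailing comma are split onto one argument per line, and the leading space after the opening `"""` of docstring summaries is stripped.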
@@ -90,7 +90,12 @@ class Portal:
         return PortalOrange.apply(self, phony)
-    def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor:
+    def copy(
+        self,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        phony: Tensor,
+    ) -> Tensor:
         """Copies the hidden tensor by a :class:`PortalCopy`.
         Give a phony and use the returning phony to keep backpropagation::
@@ -202,7 +207,10 @@ class PortalBlue(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
+    def backward(
+        ctx: Context,
+        grad_phony: Tensor,
+    ) -> Tuple[None, Tensor]:
         # The paired PortalOrange should keep the gradient.
         grad = ctx.portal.use_grad()
         return None, grad
@@ -236,7 +244,11 @@ class PortalCopy(torch.autograd.Function):
     @staticmethod
     # type: ignore
     def forward(
-        ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
+        ctx: Context,
+        portal: Portal,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        phony: Tensor,
     ) -> Tensor:
         ctx.portal = portal
@@ -248,7 +260,10 @@ class PortalCopy(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
+    def backward(
+        ctx: Context,
+        grad_phony: Tensor,
+    ) -> Tuple[None, None, None, None]:
         portal = ctx.portal
         assert portal.grad is not None
......
@@ -248,7 +248,8 @@ class Skippable(nn.Module):
 # TODO(sublee): Move to above of Skippable class for better read flow.
 def skippable(
-    stash: Iterable[str] = (), pop: Iterable[str] = (),
+    stash: Iterable[str] = (),
+    pop: Iterable[str] = (),
 ) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
     """The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
     connections. Decorated modules are called "skippable". This functionality
......
@@ -57,7 +57,12 @@ class SkipTracker:
         return self.tensors.pop((ns, name))
     def copy(
-        self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
+        self,
+        batch: Batch,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        ns: Namespace,
+        name: str,
     ) -> None:
         raise TypeError("copy is not supported for non-portal skip tensors")
@@ -147,7 +152,12 @@ class SkipTrackerThroughPotals(SkipTracker):
         return tensor
     def copy(
-        self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
+        self,
+        batch: Batch,
+        prev_stream: AbstractStream,
+        next_stream: AbstractStream,
+        ns: Namespace,
+        name: str,
     ) -> None:
         """Copies the skip tensor in the corresponding portal. The given
         micro-batch and the portal will be tied with :class:`Fork` and
......
@@ -105,7 +105,9 @@ def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None:
         out_queue.put(done)
-def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]:
+def create_workers(
+    devices: List[torch.device],
+) -> Tuple[List[InQueue], List[OutQueue]]:
     """Spawns worker threads. A worker thread is bound to a device."""
     in_queues: List[InQueue] = []
     out_queues: List[OutQueue] = []
@@ -132,7 +134,11 @@ def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]:
         out_queue = Queue()
         workers[device] = (in_queue, out_queue)
-        t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,)
+        t = Thread(
+            target=worker,
+            args=(in_queue, out_queue, device),
+            daemon=True,
+        )
         t.start()
         in_queues.append(in_queue)
@@ -160,7 +166,9 @@ def join_workers(in_queues: List[InQueue], out_queues: List[OutQueue]) -> None:
 @contextmanager
-def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
+def spawn_workers(
+    devices: List[torch.device],
+) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
     try:
         (in_queues, out_queues) = create_workers(devices)
         yield (in_queues, out_queues)
......
@@ -71,7 +71,11 @@ default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}
 default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore
-def config_auto_wrap_policy(module: nn.Module, recurse: bool, unwrapped_params: int,) -> bool:
+def config_auto_wrap_policy(
+    module: nn.Module,
+    recurse: bool,
+    unwrapped_params: int,
+) -> bool:
     """Config based policy function for :func:`auto_wrap`.
     Return true for a module to be wrapped if it is already tagged with
......
@@ -187,11 +187,11 @@ class AdaScale(Optimizer):
         self._hook()
     def _hook(self) -> None:
-        """ Internal function to register the gradient hooks.
+        """Internal function to register the gradient hooks.
         Note, don't assume every parameter will generate a gradient (i.e. triggering the hook)
         in every backward pass, which is the reason that we have ``find_unused_params`` flag
         in the DDP class in ``torch.nn.parallel``.
         """
         assert self._hook_handles == [], "Must run unhook first"
         for idx, param_group in enumerate(self._optimizer.param_groups):
@@ -200,23 +200,23 @@ class AdaScale(Optimizer):
             self._hook_handles.append(h)
     def __del__(self) -> None:
-        """ Unhook in case caller forgets to call unhook.
+        """Unhook in case caller forgets to call unhook.
         This however may not "work" since there would be circular reference
         between the hook objects and this objects. In that case, neither will
         get GC'ed. Calling unhook explicitly if you really want to delete
         AdaScale from memory.
         """
         self.unhook()
     def unhook(self) -> None:
-        """ Unregister hook handles.
+        """Unregister hook handles.
         This is public because caller may need to call this to ensure all GPU
         memory are released. Otherwise, the hook may prevent parameters from being
         released from the GPU memory pool.
         Internally, we use this to support ``add_param_group()`` API.
         """
         for h in self._hook_handles:
             h.remove()
@@ -385,7 +385,9 @@ class AdaScale(Optimizer):
         # it means that we are in backward pass.
         if self._local_grad_sqr is None:
             self._local_grad_sqr = torch.zeros(
-                len(self._optimizer.param_groups), device=grad.device, requires_grad=False,
+                len(self._optimizer.param_groups),
+                device=grad.device,
+                requires_grad=False,
             )
         self._local_grad_sqr[pg_idx] += grad.pow(2).sum()
@@ -515,9 +517,9 @@ class AdaScale(Optimizer):
         return res
     def add_param_group(self, pg: Dict) -> None:
-        """ Support adding parameter groups
+        """Support adding parameter groups
         We need to re-size some of the state and re-register the backward hooks.
         """
         assert self._local_grad_sqr is None, "Can't add parameter group during backward"
         self._optimizer.add_param_group(pg)
@@ -542,28 +544,32 @@ class AdaScale(Optimizer):
         return self._optimizer.zero_grad()
     def state_dict(self) -> Dict:
-        """ Proxy function to optimizer, checkpointing needs this.
+        """Proxy function to optimizer, checkpointing needs this.
         .. note::
             Do NOT checkpoint in the middle of gradient accumulation since
             associated AdaScale internal states are not saved in the checkpoint.
         """
         assert self._local_grad_sqr is None, "Don't checkpoint in backward"
         return self._optimizer.state_dict()
     def load_state_dict(self, data: Dict) -> None:
-        """ Proxy function to optimizer, checkpointing needs this.
+        """Proxy function to optimizer, checkpointing needs this.
         .. note::
             Do NOT checkpoint in the middle of gradient accumulation since
             associated AdaScale internal states are not saved in the checkpoint.
         """
         assert self._local_grad_sqr is None, "Don't load checkpoint in backward"
         return self._optimizer.load_state_dict(data)
-    def set_num_gradients_to_accumulate(self, num_gradients_to_accumulate: int, update_smoothing: bool = True,) -> None:
+    def set_num_gradients_to_accumulate(
+        self,
+        num_gradients_to_accumulate: int,
+        update_smoothing: bool = True,
+    ) -> None:
         """Set the number of gradients to accumulate to a new value.
         This is experimental. This could be called while training so that
......
@@ -292,7 +292,7 @@ class OSS(Optimizer):
         if clip_coef < 1:
             for device, device_params in self._per_device_params.items():
                 for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
-                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore # mypy trips on the filter
+                    p.grad.detach().mul_(clip_coef.to(device))
         return total_norm
@@ -341,7 +341,9 @@ class OSS(Optimizer):
             else:
                 obj_list = [state_to_share]
             dist.broadcast_object_list(
-                obj_list, src=self.global_rank, group=self.group,
+                obj_list,
+                src=self.global_rank,
+                group=self.group,
             )
         else:
             # Fetch the optim state from the other replicas
@@ -355,7 +357,9 @@ class OSS(Optimizer):
                 else:
                     obj_list = [torch.tensor([0], dtype=torch.uint8, device=dist_device)]
                 dist.broadcast_object_list(
-                    obj_list, src=self._local_to_global_rank[rank], group=self.group,
+                    obj_list,
+                    src=self._local_to_global_rank[rank],
+                    group=self.group,
                 )
                 replica_state = obj_list[0]
@@ -501,7 +505,7 @@ class OSS(Optimizer):
     @property
     def _local_params(self) -> List[torch.Tensor]:
-        """ Iterable which goes through the parameters that this rank owns """
+        """Iterable which goes through the parameters that this rank owns"""
         if self.__local_params is None:
             self.__local_params = list(
                 chain(
@@ -517,7 +521,7 @@ class OSS(Optimizer):
     @property
     def _param_to_index(self) -> Dict[int, int]:
-        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params """
+        """Hash table in between parameter indices in the global optimizer scheme, and the actual params"""
         if len(self.__param_to_index) == 0:
             self.__param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
......
@@ -27,8 +27,8 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
 def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
     """Do a quick test in case user called FSDP without calling torch.cuda.set_device()
     correctly. This can easily happen in cpu_offload case where the model resides on
     the CPU.
     """
     if not hasattr(process_group, "allgather"):
         # Likely a dummy pg for unit test, skip checking.
@@ -47,7 +47,7 @@ def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
 def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
     """Call _specify_ddp_gpu_num for all pytorch SyncBN layers so that it
     is happily running even without DDP. E.g. this is used by FSDP.
     """
     for layer in module.modules():
         if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
......
@@ -103,7 +103,10 @@ class ReduceScatterBucketer:
     @torch.no_grad()
     def reduce_scatter_async(
-        self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
+        self,
+        input_list: List[Tensor],
+        group: ProcessGroup,
+        callback_fn: Optional[Callable] = None,
     ) -> None:
         """
         Reduce-scatter a list of tensors asynchronously, so smaller reductions
......
@@ -381,7 +381,11 @@ class _Block(Base):
         self.ln_1 = nn.LayerNorm(embed_dim)
         self.ln_2 = nn.LayerNorm(embed_dim)
         self.attn = nn.MultiheadAttention(embed_dim, num_heads)  # type: ignore
-        self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
+        self.mlp = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim * 4),
+            nn.GELU(),
+            nn.Linear(embed_dim * 4, embed_dim),
+        )
     def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
         x = inputs[0]
@@ -701,7 +705,7 @@ def in_temporary_directory() -> Generator:
 @contextlib.contextmanager
 def temp_files_ctx(num: int) -> Generator:
-    """ A context to get tempfiles and ensure they are cleaned up. """
+    """A context to get tempfiles and ensure they are cleaned up."""
     files = [tempfile.mkstemp()[1] for _ in range(num)]
     try:
......
@@ -12,7 +12,7 @@ import torch
 def find_tensor_by_shape(target_shape: Tuple, only_param: bool = True) -> bool:
-    """ Find a tensor from the heap
+    """Find a tensor from the heap
     Args:
         target_shape (tuple):
......
@@ -27,4 +27,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "helpers", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
@@ -2,4 +2,4 @@
 -r requirements.txt
 # For pre-commit hooks.
-pre-commit
+pre-commit >= 2.15.0
@@ -6,11 +6,11 @@
 # function typing with mypy.
 # - if you change versions below, please make sure it is in-sync with
 #   .pre-commit-config.yaml for pre-commit.
-black == 19.10b0
-flake8 == 3.7.9
-flake8-annotations == 2.6.2
-isort == 5.6.4
-mypy == 0.790
+black == 21.10b0
+flake8 == 4.0.1
+flake8-annotations == 2.7.0
+isort == 5.10.1
+mypy == 0.910
 # Tools for unit tests & coverage.
 pytest == 5.4.1
......
 # FairScale should only depends on torch, not things higher level than torch.
-torch >= 1.7.0
+torch >= 1.8.0
@@ -7,7 +7,7 @@ from collections import namedtuple
 from typing import List, Sequence
 from .container import ModuleList
-_ASMoutput = namedtuple('ASMoutput', ['output', 'loss'])
+_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])
 class AdaptiveLogSoftmaxWithLoss(Module):
......
@@ -42,7 +42,7 @@ class MySGD(Optimizer):
         super(MySGD, self).__setstate__(state)
     def step(self, closure=None):
-        """ Performs a single optimization step.
+        """Performs a single optimization step.
         Arguments:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.
@@ -83,7 +83,10 @@ class AMPnetDelegate(object):
 class FakeDataset(Dataset):
     def __init__(
-        self, input_dim=10, output_dim=10, total_samples=100,
+        self,
+        input_dim=10,
+        output_dim=10,
+        total_samples=100,
     ):
         self.input_dim = input_dim
         self.output_dim = output_dim
@@ -104,7 +107,13 @@ class FakeDataset(Dataset):
 @torch_spawn([2])
 def async_event_loop_interleave_simple():
     model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
-    pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
+    pipe = AMPnetPipe(
+        module=model,
+        balance=[2, 2],
+        worker_map=get_worker_map(),
+        chunks=10,
+        checkpoint="never",
+    )
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()
@@ -116,7 +125,13 @@ def async_event_loop_interleave_simple():
 @torch_spawn([4])
 def async_event_loop_interleave_hard():
     model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
-    pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
+    pipe = AMPnetPipe(
+        module=model,
+        balance=[1, 1, 1, 1],
+        worker_map=get_worker_map(),
+        chunks=10,
+        checkpoint="never",
+    )
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()
......
@@ -100,11 +100,19 @@ def find_memory_used_by_model(model_class: Type[nn.Module], device: torch.device
 def _prepare_single_device_module(
-    rank, world_size, tempfile, devices: List[torch.device], slowmo_init_dict: Dict[Any, Any], global_batch_size: int,
+    rank,
+    world_size,
+    tempfile,
+    devices: List[torch.device],
+    slowmo_init_dict: Dict[Any, Any],
+    global_batch_size: int,
 ) -> Tuple[nn.Module, gossip.SlowMoDistributedDataParallel, torch.Tensor, torch.Tensor]:
     if not torch.distributed.is_initialized():
         torch.distributed.init_process_group(
-            "nccl", init_method=f"file://{tempfile}", rank=rank, world_size=world_size,
+            "nccl",
+            init_method=f"file://{tempfile}",
+            rank=rank,
+            world_size=world_size,
         )
     model = Net()
     slowmo_model = gossip.SlowMoDistributedDataParallel(
@@ -145,7 +153,9 @@ def run_test_slowmo_with_slowmo_freq_1(
         rank, world_size, tempfile, devices, slowmo_init_dict, global_batch_size
     )
     model_optimizer = torch.optim.SGD(
-        model.parameters(), lr=slowmo_model.slowmo_lr, momentum=slowmo_model.slowmo_momentum,
+        model.parameters(),
+        lr=slowmo_model.slowmo_lr,
+        momentum=slowmo_model.slowmo_momentum,
    )
     slowmo_model_optimizer = torch.optim.SGD(slowmo_model.module.parameters(), lr=1, momentum=0)
     slowmo_model._init_global_momentum_buffers(slowmo_model_optimizer)
@@ -261,7 +271,9 @@ def run_test_slowmo_with_slowmo_freq_ge_2(
     base_lr, base_momentum = 1, 0
     model_optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=base_momentum)
     model_slow_momentum_optimizer = torch.optim.SGD(
-        model.parameters(), lr=slowmo_model.slowmo_lr, momentum=slowmo_model.slowmo_momentum,
+        model.parameters(),
+        lr=slowmo_model.slowmo_lr,
+        momentum=slowmo_model.slowmo_momentum,
     )
     slowmo_model_optimizer = torch.optim.SGD(slowmo_model.module.parameters(), lr=base_lr, momentum=base_momentum)
     slowmo_model._init_global_momentum_buffers(slowmo_model_optimizer)
@@ -329,7 +341,10 @@ def run_test_memory_usage_localsgd_with_slowmo(
     if not torch.distributed.is_initialized():
         torch.distributed.init_process_group(
-            "nccl", init_method=f"file://{tempfile}", rank=rank, world_size=world_size,
+            "nccl",
+            init_method=f"file://{tempfile}",
+            rank=rank,
+            world_size=world_size,
         )
     if use_gossip_data_parallel:
         model: nn.Module = gossip.SlowMoDistributedDataParallel(
@@ -540,7 +555,11 @@ def run_max_memory_used_localsgd_slowmo_memory_efficient(rank, world_size, tempf
     # Memory usage when running optimization locally on a single GPU
     max_memory_local = run_test_memory_usage_localsgd_with_slowmo(
-        rank, world_size, tempfile_1, {"localsgd_frequency": 1}, use_gossip_data_parallel=False,
+        rank,
+        world_size,
+        tempfile_1,
+        {"localsgd_frequency": 1},
+        use_gossip_data_parallel=False,
     )
     # Memory usage when running optimization using LocalSGD-SlowMo
@@ -586,7 +605,10 @@ def run_max_memory_used_localsgd_slowmo_memory_efficient(rank, world_size, tempf
 def test_max_memory_used_localsgd_slowmo_memory_efficient() -> None:
     world_size = 2
     spawn_for_all_world_sizes(
-        run_max_memory_used_localsgd_slowmo_memory_efficient, world_sizes=[world_size], args=(), deterministic=True,
+        run_max_memory_used_localsgd_slowmo_memory_efficient,
+        world_sizes=[world_size],
+        args=(),
+        deterministic=True,
     )
@@ -595,7 +617,11 @@ def run_max_memory_used_slowmo_memory_efficient(rank: int, world_size: int, temp
     devices = [torch.device("cuda:" + str(i)) for i in int_devices]
     max_memory_local = run_test_memory_usage_localsgd_with_slowmo(
-        rank, world_size, tempfile_1, {"localsgd_frequency": 1}, use_gossip_data_parallel=False,
+        rank,
+        world_size,
+        tempfile_1,
+        {"localsgd_frequency": 1},
+        use_gossip_data_parallel=False,
     )
     max_memory_slowmo = run_test_memory_usage_localsgd_with_slowmo(
         rank,
@@ -629,7 +655,10 @@ def run_max_memory_used_slowmo_memory_efficient(rank: int, world_size: int, temp
 def test_max_memory_used_slowmo_memory_efficient() -> None:
     world_size = 2
     spawn_for_all_world_sizes(
-        run_max_memory_used_slowmo_memory_efficient, world_sizes=[world_size], args=(), deterministic=True,
+        run_max_memory_used_slowmo_memory_efficient,
+        world_sizes=[world_size],
+        args=(),
+        deterministic=True,
     )
@@ -638,7 +667,11 @@ def run_max_memory_used_slowmo_no_sharding(rank, world_size, tempfile_1, tempfil
     devices = [torch.device("cuda:" + str(i)) for i in int_devices]
     max_memory_local = run_test_memory_usage_localsgd_with_slowmo(
-        rank, world_size, tempfile_1, {"localsgd_frequency": 1}, use_gossip_data_parallel=False,
+        rank,
+        world_size,
+        tempfile_1,
+        {"localsgd_frequency": 1},
+        use_gossip_data_parallel=False,
     )
     max_memory_slowmo = run_test_memory_usage_localsgd_with_slowmo(
         rank,
@@ -673,7 +706,10 @@ def run_max_memory_used_slowmo_no_sharding(rank, world_size, tempfile_1, tempfil
 def test_max_memory_used_slowmo_no_sharding() -> None:
     world_size = 2
     spawn_for_all_world_sizes(
-        run_max_memory_used_slowmo_no_sharding, world_sizes=[world_size], args=(), deterministic=True,
+        run_max_memory_used_slowmo_no_sharding,
+        world_sizes=[world_size],
+        args=(),
+        deterministic=True,
     )
......
@@ -62,12 +62,12 @@ def create_sequence_pipeline(
     layers: List[RemoteModuleParams], balance: List[int], devices: List[str], **kwargs: Any
 ) -> DistributedPipeline:
     """A simple helper function to create a pipeline from list of pipeline-modules that run sequentially.
     Args:
         layers: list of modules. They should not be already assigned a remote-device.
         balance: a list of integers how layers should be paritioned. Sum of numbers in 'balance'
             should be equal to the number of layers.
         devices: specification of remote device for each partition. Should be of the same length
             as 'balance'.
     """
     remote_modules: List[RemoteModule] = []
     index = 0
@@ -190,7 +190,11 @@ def update(devices):
     x = torch.randn(8, 4).to(device)
     model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
     pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2])
-    opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,)
+    opt = DistributedOptimizer(
+        torch.optim.SGD,
+        pipe.parameter_rrefs(),
+        lr=0.05,
+    )
     losses = []
     for i in range(2):
         with dist_autograd.context() as context_id:
@@ -247,7 +251,11 @@ def multi_input_multi_output_layers(devices):
     assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
     parameter_rrefs = pipe.parameter_rrefs()
     assert len(parameter_rrefs) == 6
-    opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
+    opt = DistributedOptimizer(
+        torch.optim.SGD,
+        parameter_rrefs,
+        lr=0.05,
+    )
     losses = []
     for i in range(2):
         with dist_autograd.context() as context_id:
@@ -301,7 +309,11 @@ def auto_graph_extract(devices):
     assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
     parameter_rrefs = pipe.parameter_rrefs()
     assert len(parameter_rrefs) == 8
-    opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
+    opt = DistributedOptimizer(
+        torch.optim.SGD,
+        parameter_rrefs,
+        lr=0.05,
+    )
     losses = []
     for i in range(2):
         with dist_autograd.context() as context_id:
......
@@ -111,7 +111,9 @@ def test_memory_tracking_ddp():
     with temp_files_ctx(num=2) as sync_files:
         world_size = 2
         mp.spawn(
-            _layer_memory_tracking_ddp_worker, (sync_files, world_size), nprocs=world_size,
+            _layer_memory_tracking_ddp_worker,
+            (sync_files, world_size),
+            nprocs=world_size,
         )
@@ -129,7 +131,13 @@ def _layer_memory_tracking_ddp_worker(gpu_id: int, sync_files: Tuple[str, str],
     # Create a simple model
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
-    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10),)
+    model = nn.Sequential(
+        nn.Linear(10, 32),
+        nn.ReLU(),
+        nn.Linear(32, 32),
+        nn.ReLU(),
+        nn.Linear(32, 10),
+    )
     model = model.cuda(gpu_id)
     ddp_model = DistributedDataParallel(model, device_ids=[gpu_id])
@@ -156,7 +164,9 @@ def test_memory_tracking_fsdp():
     with temp_files_ctx(num=2) as sync_files:
         world_size = 2
         mp.spawn(
-            _layer_memory_tracking_fsdp_worker, (sync_files, world_size), nprocs=world_size,
+            _layer_memory_tracking_fsdp_worker,
+            (sync_files, world_size),
+            nprocs=world_size,
        )
@@ -181,9 +191,17 @@ def _layer_memory_tracking_fsdp_worker(gpu_id: int, sync_files: Tuple[str, str],
     model = nn.Sequential(
         nn.Linear(10, 10).cuda(gpu_id),
         nn.ReLU(),
-        FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=False, process_group=group,),
+        FullyShardedDataParallel(
+            nn.Linear(10, 10).cuda(gpu_id),
+            flatten_parameters=False,
+            process_group=group,
+        ),
         nn.ReLU(),
-        FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=True, process_group=group,),
+        FullyShardedDataParallel(
+            nn.Linear(10, 10).cuda(gpu_id),
+            flatten_parameters=True,
+            process_group=group,
+        ),
     )
     model = model.cuda(gpu_id)
     dist_model = FullyShardedDataParallel(model, flatten_parameters=False, process_group=group)
......