Unverified commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Setup pre-commit github action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
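
Most of the hunks below are mechanical reformatting produced by the newer black pin (21.10b0, see the requirements changes further down): black's "magic trailing comma" keeps any call or signature that already ends in a trailing comma exploded onto one argument per line, and running `pre-commit run --all-files` applies that to the whole tree. A minimal, self-contained illustration (the `configure` function is hypothetical, not from this commit):

```python
def configure(host: str, port: int, retries: int) -> str:
    # Hypothetical helper, only here to show the formatting behaviour.
    return f"{host}:{port} (retries={retries})"


# Written on one line with a trailing comma, e.g.
#   result = configure("localhost", 8080, 3,)
# black >= 20.8b0 (including the 21.10b0 pinned here) reformats the call site to:
result = configure(
    "localhost",
    8080,
    3,
)
print(result)
```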
......@@ -90,7 +90,12 @@ class Portal:
return PortalOrange.apply(self, phony)
def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor:
def copy(
self,
prev_stream: AbstractStream,
next_stream: AbstractStream,
phony: Tensor,
) -> Tensor:
"""Copies the hidden tensor by a :class:`PortalCopy`.
Give a phony and use the returning phony to keep backpropagation::
......@@ -202,7 +207,10 @@ class PortalBlue(torch.autograd.Function):
@staticmethod
# type: ignore
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
def backward(
ctx: Context,
grad_phony: Tensor,
) -> Tuple[None, Tensor]:
# The paired PortalOrange should keep the gradient.
grad = ctx.portal.use_grad()
return None, grad
......@@ -236,7 +244,11 @@ class PortalCopy(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(
ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
ctx: Context,
portal: Portal,
prev_stream: AbstractStream,
next_stream: AbstractStream,
phony: Tensor,
) -> Tensor:
ctx.portal = portal
......@@ -248,7 +260,10 @@ class PortalCopy(torch.autograd.Function):
@staticmethod
# type: ignore
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
def backward(
ctx: Context,
grad_phony: Tensor,
) -> Tuple[None, None, None, None]:
portal = ctx.portal
assert portal.grad is not None
......
......@@ -248,7 +248,8 @@ class Skippable(nn.Module):
# TODO(sublee): Move to above of Skippable class for better read flow.
def skippable(
stash: Iterable[str] = (), pop: Iterable[str] = (),
stash: Iterable[str] = (),
pop: Iterable[str] = (),
) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
"""The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
connections. Decorated modules are called "skippable". This functionality
......
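
For context on the `skippable` decorator whose signature is re-wrapped above: it tags an `nn.Module` as stashing or popping named skip tensors inside a pipeline. A minimal sketch of the documented stash/pop pattern; the import path and the `stash`/`pop` helpers are assumptions about the surrounding package, not part of this diff:

```python
import torch
from torch import nn

# Assumed import path for the decorator and its helpers.
from fairscale.nn.pipe.skip import pop, skippable, stash


@skippable(stash=["skip"])
class Stash(nn.Module):
    def forward(self, input: torch.Tensor):
        # Stash the activation so a later partition can pop it.
        yield stash("skip", input)
        return input


@skippable(pop=["skip"])
class Pop(nn.Module):
    def forward(self, input: torch.Tensor):
        # Pop the stashed activation and fuse it back in.
        skip = yield pop("skip")
        return input + skip
```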
......@@ -57,7 +57,12 @@ class SkipTracker:
return self.tensors.pop((ns, name))
def copy(
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
self,
batch: Batch,
prev_stream: AbstractStream,
next_stream: AbstractStream,
ns: Namespace,
name: str,
) -> None:
raise TypeError("copy is not supported for non-portal skip tensors")
......@@ -147,7 +152,12 @@ class SkipTrackerThroughPotals(SkipTracker):
return tensor
def copy(
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
self,
batch: Batch,
prev_stream: AbstractStream,
next_stream: AbstractStream,
ns: Namespace,
name: str,
) -> None:
"""Copies the skip tensor in the corresponding portal. The given
micro-batch and the portal will be tied with :class:`Fork` and
......
......@@ -105,7 +105,9 @@ def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None
out_queue.put(done)
def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]:
def create_workers(
devices: List[torch.device],
) -> Tuple[List[InQueue], List[OutQueue]]:
"""Spawns worker threads. A worker thread is bound to a device."""
in_queues: List[InQueue] = []
out_queues: List[OutQueue] = []
......@@ -132,7 +134,11 @@ def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[Ou
out_queue = Queue()
workers[device] = (in_queue, out_queue)
t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,)
t = Thread(
target=worker,
args=(in_queue, out_queue, device),
daemon=True,
)
t.start()
in_queues.append(in_queue)
......@@ -160,7 +166,9 @@ def join_workers(in_queues: List[InQueue], out_queues: List[OutQueue]) -> None:
@contextmanager
def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
def spawn_workers(
devices: List[torch.device],
) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
try:
(in_queues, out_queues) = create_workers(devices)
yield (in_queues, out_queues)
......
......@@ -71,7 +71,11 @@ default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}
default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore
def config_auto_wrap_policy(module: nn.Module, recurse: bool, unwrapped_params: int,) -> bool:
def config_auto_wrap_policy(
module: nn.Module,
recurse: bool,
unwrapped_params: int,
) -> bool:
"""Config based policy function for :func:`auto_wrap`.
Return true for a module to be wrapped if it is already tagged with
......
......@@ -187,11 +187,11 @@ class AdaScale(Optimizer):
self._hook()
def _hook(self) -> None:
""" Internal function to register the gradient hooks.
"""Internal function to register the gradient hooks.
Note, don't assume every parameter will generate a gradient (i.e. triggering the hook)
in every backward pass, which is the reason that we have ``find_unused_params`` flag
in the DDP class in ``torch.nn.parallel``.
Note, don't assume every parameter will generate a gradient (i.e. triggering the hook)
in every backward pass, which is the reason that we have ``find_unused_params`` flag
in the DDP class in ``torch.nn.parallel``.
"""
assert self._hook_handles == [], "Must run unhook first"
for idx, param_group in enumerate(self._optimizer.param_groups):
......@@ -200,23 +200,23 @@ class AdaScale(Optimizer):
self._hook_handles.append(h)
def __del__(self) -> None:
""" Unhook in case caller forgets to call unhook.
"""Unhook in case caller forgets to call unhook.
This however may not "work" since there would be circular reference
between the hook objects and this objects. In that case, neither will
get GC'ed. Calling unhook explicitly if you really want to delete
AdaScale from memory.
This however may not "work" since there would be circular reference
between the hook objects and this objects. In that case, neither will
get GC'ed. Calling unhook explicitly if you really want to delete
AdaScale from memory.
"""
self.unhook()
def unhook(self) -> None:
""" Unregister hook handles.
"""Unregister hook handles.
This is public because caller may need to call this to ensure all GPU
memory are released. Otherwise, the hook may prevent parameters from being
released from the GPU memory pool.
This is public because caller may need to call this to ensure all GPU
memory are released. Otherwise, the hook may prevent parameters from being
released from the GPU memory pool.
Internally, we use this to support ``add_param_group()`` API.
Internally, we use this to support ``add_param_group()`` API.
"""
for h in self._hook_handles:
h.remove()
......@@ -385,7 +385,9 @@ class AdaScale(Optimizer):
# it means that we are in backward pass.
if self._local_grad_sqr is None:
self._local_grad_sqr = torch.zeros(
len(self._optimizer.param_groups), device=grad.device, requires_grad=False,
len(self._optimizer.param_groups),
device=grad.device,
requires_grad=False,
)
self._local_grad_sqr[pg_idx] += grad.pow(2).sum()
......@@ -515,9 +517,9 @@ class AdaScale(Optimizer):
return res
def add_param_group(self, pg: Dict) -> None:
""" Support adding parameter groups
"""Support adding parameter groups
We need to re-size some of the state and re-register the backward hooks.
We need to re-size some of the state and re-register the backward hooks.
"""
assert self._local_grad_sqr is None, "Can't add parameter group during backward"
self._optimizer.add_param_group(pg)
......@@ -542,28 +544,32 @@ class AdaScale(Optimizer):
return self._optimizer.zero_grad()
def state_dict(self) -> Dict:
""" Proxy function to optimizer, checkpointing needs this.
"""Proxy function to optimizer, checkpointing needs this.
.. note::
.. note::
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
"""
assert self._local_grad_sqr is None, "Don't checkpoint in backward"
return self._optimizer.state_dict()
def load_state_dict(self, data: Dict) -> None:
""" Proxy function to optimizer, checkpointing needs this.
"""Proxy function to optimizer, checkpointing needs this.
.. note::
.. note::
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
"""
assert self._local_grad_sqr is None, "Don't load checkpoint in backward"
return self._optimizer.load_state_dict(data)
def set_num_gradients_to_accumulate(self, num_gradients_to_accumulate: int, update_smoothing: bool = True,) -> None:
def set_num_gradients_to_accumulate(
self,
num_gradients_to_accumulate: int,
update_smoothing: bool = True,
) -> None:
"""Set the number of gradients to accumulate to a new value.
This is experimental. This could be called while training so that
......
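
The AdaScale docstring fixes above touch its hook management, checkpointing, and gradient-accumulation APIs. A minimal single-process sketch of how they fit together, assuming the usual `fairscale.optim.AdaScale` import and constructor arguments:

```python
import torch
from fairscale.optim import AdaScale  # assumed import path

model = torch.nn.Linear(10, 10)
# Wrap a plain torch optimizer; AdaScale registers backward hooks on the
# parameters and rescales the step based on its gain estimate.
optim = AdaScale(
    torch.optim.SGD(model.parameters(), lr=0.1),
    num_gradients_to_accumulate=2,
)

for step in range(4):
    loss = model(torch.randn(8, 10)).sum()
    loss.backward()
    if (step + 1) % 2 == 0:  # step once per accumulation window
        optim.step()
        optim.zero_grad()
        # Per the docstrings above: only checkpoint here, never in the
        # middle of an accumulation window.
        ckpt = optim.state_dict()
```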
......@@ -292,7 +292,7 @@ class OSS(Optimizer):
if clip_coef < 1:
for device, device_params in self._per_device_params.items():
for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
p.grad.detach().mul_(clip_coef.to(device)) # type: ignore # mypy trips on the filter
p.grad.detach().mul_(clip_coef.to(device))
return total_norm
......@@ -341,7 +341,9 @@ class OSS(Optimizer):
else:
obj_list = [state_to_share]
dist.broadcast_object_list(
obj_list, src=self.global_rank, group=self.group,
obj_list,
src=self.global_rank,
group=self.group,
)
else:
# Fetch the optim state from the other replicas
......@@ -355,7 +357,9 @@ class OSS(Optimizer):
else:
obj_list = [torch.tensor([0], dtype=torch.uint8, device=dist_device)]
dist.broadcast_object_list(
obj_list, src=self._local_to_global_rank[rank], group=self.group,
obj_list,
src=self._local_to_global_rank[rank],
group=self.group,
)
replica_state = obj_list[0]
......@@ -501,7 +505,7 @@ class OSS(Optimizer):
@property
def _local_params(self) -> List[torch.Tensor]:
""" Iterable which goes through the parameters that this rank owns """
"""Iterable which goes through the parameters that this rank owns"""
if self.__local_params is None:
self.__local_params = list(
chain(
......@@ -517,7 +521,7 @@ class OSS(Optimizer):
@property
def _param_to_index(self) -> Dict[int, int]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params """
"""Hash table in between parameter indices in the global optimizer scheme, and the actual params"""
if len(self.__param_to_index) == 0:
self.__param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
......
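
The OSS hunks above (sharded gradient clipping, optimizer-state broadcast, and the `_local_params` / `_param_to_index` docstrings) are easier to read with a usage sketch in mind. A single-rank sketch, assuming the `fairscale.optim.oss.OSS` import path and a file-store rendezvous like the tests further down use:

```python
import tempfile

import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS  # assumed import path

# Single-rank process group, just so the sharded optimizer can be constructed.
dist.init_process_group(
    "gloo",
    init_method=f"file://{tempfile.mkstemp()[1]}",
    rank=0,
    world_size=1,
)

model = torch.nn.Linear(8, 8)
# Each rank only materializes optimizer state for the parameters it owns.
optim = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optim.clip_grad_norm(max_norm=1.0)  # the sharded grad clipping patched above
optim.step()
```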
......@@ -27,8 +27,8 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
"""Do a quick test in case user called FSDP without calling torch.cuda.set_device()
correctly. This can easily happen in cpu_offload case where the model resides on
the CPU.
correctly. This can easily happen in cpu_offload case where the model resides on
the CPU.
"""
if not hasattr(process_group, "allgather"):
# Likely a dummy pg for unit test, skip checking.
......@@ -47,7 +47,7 @@ def validate_process_group(device: torch.device, process_group: ProcessGroup) ->
def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
"""Call _specify_ddp_gpu_num for all pytorch SyncBN layers so that it
is happily running even without DDP. E.g. this is used by FSDP.
is happily running even without DDP. E.g. this is used by FSDP.
"""
for layer in module.modules():
if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
......
......@@ -103,7 +103,10 @@ class ReduceScatterBucketer:
@torch.no_grad()
def reduce_scatter_async(
self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
self,
input_list: List[Tensor],
group: ProcessGroup,
callback_fn: Optional[Callable] = None,
) -> None:
"""
Reduce-scatter a list of tensors asynchronously, so smaller reductions
......
......@@ -381,7 +381,11 @@ class _Block(Base):
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore
self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.GELU(),
nn.Linear(embed_dim * 4, embed_dim),
)
def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
x = inputs[0]
......@@ -701,7 +705,7 @@ def in_temporary_directory() -> Generator:
@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
""" A context to get tempfiles and ensure they are cleaned up. """
"""A context to get tempfiles and ensure they are cleaned up."""
files = [tempfile.mkstemp()[1] for _ in range(num)]
try:
......
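
The hunk above cuts off the body of `temp_files_ctx` after the `try:`; for reference, a self-contained sketch of the whole helper, with the cleanup step assumed to mirror the upstream implementation:

```python
import contextlib
import os
import tempfile
from typing import Generator


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
    """A context to get tempfiles and ensure they are cleaned up."""
    files = [tempfile.mkstemp()[1] for _ in range(num)]
    try:
        yield files
    finally:
        # Assumed cleanup: delete every temp file once the caller is done.
        for f in files:
            os.remove(f)


# Used elsewhere in this diff, e.g. `with temp_files_ctx(num=2) as sync_files:`.
```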
......@@ -12,7 +12,7 @@ import torch
def find_tensor_by_shape(target_shape: Tuple, only_param: bool = True) -> bool:
""" Find a tensor from the heap
"""Find a tensor from the heap
Args:
target_shape (tuple):
......
......@@ -27,4 +27,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "helpers", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
......@@ -2,4 +2,4 @@
-r requirements.txt
# For pre-commit hooks.
pre-commit
pre-commit >= 2.15.0
......@@ -6,11 +6,11 @@
# function typing with mypy.
# - if you change versions below, please make sure it is in-sync with
# .pre-commit-config.yaml for pre-commit.
black == 19.10b0
flake8 == 3.7.9
flake8-annotations == 2.6.2
isort == 5.6.4
mypy == 0.790
black == 21.10b0
flake8 == 4.0.1
flake8-annotations == 2.7.0
isort == 5.10.1
mypy == 0.910
# Tools for unit tests & coverage.
pytest == 5.4.1
......
# FairScale should only depends on torch, not things higher level than torch.
torch >= 1.7.0
torch >= 1.8.0
......@@ -7,7 +7,7 @@ from collections import namedtuple
from typing import List, Sequence
from .container import ModuleList
_ASMoutput = namedtuple('ASMoutput', ['output', 'loss'])
_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])
class AdaptiveLogSoftmaxWithLoss(Module):
......
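
The `_ASMoutput` change above makes the namedtuple's typename match the name it is bound to, which mypy (pinned to 0.910 in this commit) checks for, and which also keeps `repr()` and pickling consistent. A small stand-alone illustration, not taken from this diff:

```python
from collections import namedtuple

# The first argument is the class's typename: it shows up in repr() and is
# the name pickle resolves in the defining module, so it should match the
# variable the class is assigned to.
_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])

print(_ASMoutput(output=1.0, loss=0.5))  # _ASMoutput(output=1.0, loss=0.5)
```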
......@@ -42,7 +42,7 @@ class MySGD(Optimizer):
super(MySGD, self).__setstate__(state)
def step(self, closure=None):
""" Performs a single optimization step.
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
......@@ -83,7 +83,10 @@ class AMPnetDelegate(object):
class FakeDataset(Dataset):
def __init__(
self, input_dim=10, output_dim=10, total_samples=100,
self,
input_dim=10,
output_dim=10,
total_samples=100,
):
self.input_dim = input_dim
self.output_dim = output_dim
......@@ -104,7 +107,13 @@ class FakeDataset(Dataset):
@torch_spawn([2])
def async_event_loop_interleave_simple():
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
pipe = AMPnetPipe(
module=model,
balance=[2, 2],
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
)
fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss()
......@@ -116,7 +125,13 @@ def async_event_loop_interleave_simple():
@torch_spawn([4])
def async_event_loop_interleave_hard():
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
pipe = AMPnetPipe(
module=model,
balance=[1, 1, 1, 1],
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
)
fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss()
......
......@@ -100,11 +100,19 @@ def find_memory_used_by_model(model_class: Type[nn.Module], device: torch.device
def _prepare_single_device_module(
rank, world_size, tempfile, devices: List[torch.device], slowmo_init_dict: Dict[Any, Any], global_batch_size: int,
rank,
world_size,
tempfile,
devices: List[torch.device],
slowmo_init_dict: Dict[Any, Any],
global_batch_size: int,
) -> Tuple[nn.Module, gossip.SlowMoDistributedDataParallel, torch.Tensor, torch.Tensor]:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl", init_method=f"file://{tempfile}", rank=rank, world_size=world_size,
"nccl",
init_method=f"file://{tempfile}",
rank=rank,
world_size=world_size,
)
model = Net()
slowmo_model = gossip.SlowMoDistributedDataParallel(
......@@ -145,7 +153,9 @@ def run_test_slowmo_with_slowmo_freq_1(
rank, world_size, tempfile, devices, slowmo_init_dict, global_batch_size
)
model_optimizer = torch.optim.SGD(
model.parameters(), lr=slowmo_model.slowmo_lr, momentum=slowmo_model.slowmo_momentum,
model.parameters(),
lr=slowmo_model.slowmo_lr,
momentum=slowmo_model.slowmo_momentum,
)
slowmo_model_optimizer = torch.optim.SGD(slowmo_model.module.parameters(), lr=1, momentum=0)
slowmo_model._init_global_momentum_buffers(slowmo_model_optimizer)
......@@ -261,7 +271,9 @@ def run_test_slowmo_with_slowmo_freq_ge_2(
base_lr, base_momentum = 1, 0
model_optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=base_momentum)
model_slow_momentum_optimizer = torch.optim.SGD(
model.parameters(), lr=slowmo_model.slowmo_lr, momentum=slowmo_model.slowmo_momentum,
model.parameters(),
lr=slowmo_model.slowmo_lr,
momentum=slowmo_model.slowmo_momentum,
)
slowmo_model_optimizer = torch.optim.SGD(slowmo_model.module.parameters(), lr=base_lr, momentum=base_momentum)
slowmo_model._init_global_momentum_buffers(slowmo_model_optimizer)
......@@ -329,7 +341,10 @@ def run_test_memory_usage_localsgd_with_slowmo(
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl", init_method=f"file://{tempfile}", rank=rank, world_size=world_size,
"nccl",
init_method=f"file://{tempfile}",
rank=rank,
world_size=world_size,
)
if use_gossip_data_parallel:
model: nn.Module = gossip.SlowMoDistributedDataParallel(
......@@ -540,7 +555,11 @@ def run_max_memory_used_localsgd_slowmo_memory_efficient(rank, world_size, tempf
# Memory usage when running optimization locally on a single GPU
max_memory_local = run_test_memory_usage_localsgd_with_slowmo(
rank, world_size, tempfile_1, {"localsgd_frequency": 1}, use_gossip_data_parallel=False,
rank,
world_size,
tempfile_1,
{"localsgd_frequency": 1},
use_gossip_data_parallel=False,
)
# Memory usage when running optimization using LocalSGD-SlowMo
......@@ -586,7 +605,10 @@ def run_max_memory_used_localsgd_slowmo_memory_efficient(rank, world_size, tempf
def test_max_memory_used_localsgd_slowmo_memory_efficient() -> None:
world_size = 2
spawn_for_all_world_sizes(
run_max_memory_used_localsgd_slowmo_memory_efficient, world_sizes=[world_size], args=(), deterministic=True,
run_max_memory_used_localsgd_slowmo_memory_efficient,
world_sizes=[world_size],
args=(),
deterministic=True,
)
......@@ -595,7 +617,11 @@ def run_max_memory_used_slowmo_memory_efficient(rank: int, world_size: int, temp
devices = [torch.device("cuda:" + str(i)) for i in int_devices]
max_memory_local = run_test_memory_usage_localsgd_with_slowmo(
rank, world_size, tempfile_1, {"localsgd_frequency": 1}, use_gossip_data_parallel=False,
rank,
world_size,
tempfile_1,
{"localsgd_frequency": 1},
use_gossip_data_parallel=False,
)
max_memory_slowmo = run_test_memory_usage_localsgd_with_slowmo(
rank,
......@@ -629,7 +655,10 @@ def run_max_memory_used_slowmo_memory_efficient(rank: int, world_size: int, temp
def test_max_memory_used_slowmo_memory_efficient() -> None:
world_size = 2
spawn_for_all_world_sizes(
run_max_memory_used_slowmo_memory_efficient, world_sizes=[world_size], args=(), deterministic=True,
run_max_memory_used_slowmo_memory_efficient,
world_sizes=[world_size],
args=(),
deterministic=True,
)
......@@ -638,7 +667,11 @@ def run_max_memory_used_slowmo_no_sharding(rank, world_size, tempfile_1, tempfil
devices = [torch.device("cuda:" + str(i)) for i in int_devices]
max_memory_local = run_test_memory_usage_localsgd_with_slowmo(
rank, world_size, tempfile_1, {"localsgd_frequency": 1}, use_gossip_data_parallel=False,
rank,
world_size,
tempfile_1,
{"localsgd_frequency": 1},
use_gossip_data_parallel=False,
)
max_memory_slowmo = run_test_memory_usage_localsgd_with_slowmo(
rank,
......@@ -673,7 +706,10 @@ def run_max_memory_used_slowmo_no_sharding(rank, world_size, tempfile_1, tempfil
def test_max_memory_used_slowmo_no_sharding() -> None:
world_size = 2
spawn_for_all_world_sizes(
run_max_memory_used_slowmo_no_sharding, world_sizes=[world_size], args=(), deterministic=True,
run_max_memory_used_slowmo_no_sharding,
world_sizes=[world_size],
args=(),
deterministic=True,
)
......
......@@ -62,12 +62,12 @@ def create_sequence_pipeline(
layers: List[RemoteModuleParams], balance: List[int], devices: List[str], **kwargs: Any
) -> DistributedPipeline:
"""A simple helper function to create a pipeline from list of pipeline-modules that run sequentially.
Args:
layers: list of modules. They should not be already assigned a remote-device.
balance: a list of integers how layers should be paritioned. Sum of numbers in 'balance'
should be equal to the number of layers.
devices: specification of remote device for each partition. Should be of the same length
as 'balance'.
Args:
layers: list of modules. They should not be already assigned a remote-device.
balance: a list of integers how layers should be paritioned. Sum of numbers in 'balance'
should be equal to the number of layers.
devices: specification of remote device for each partition. Should be of the same length
as 'balance'.
"""
remote_modules: List[RemoteModule] = []
index = 0
......@@ -190,7 +190,11 @@ def update(devices):
x = torch.randn(8, 4).to(device)
model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2])
opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,)
opt = DistributedOptimizer(
torch.optim.SGD,
pipe.parameter_rrefs(),
lr=0.05,
)
losses = []
for i in range(2):
with dist_autograd.context() as context_id:
......@@ -247,7 +251,11 @@ def multi_input_multi_output_layers(devices):
assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
parameter_rrefs = pipe.parameter_rrefs()
assert len(parameter_rrefs) == 6
opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
opt = DistributedOptimizer(
torch.optim.SGD,
parameter_rrefs,
lr=0.05,
)
losses = []
for i in range(2):
with dist_autograd.context() as context_id:
......@@ -301,7 +309,11 @@ def auto_graph_extract(devices):
assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
parameter_rrefs = pipe.parameter_rrefs()
assert len(parameter_rrefs) == 8
opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
opt = DistributedOptimizer(
torch.optim.SGD,
parameter_rrefs,
lr=0.05,
)
losses = []
for i in range(2):
with dist_autograd.context() as context_id:
......
......@@ -111,7 +111,9 @@ def test_memory_tracking_ddp():
with temp_files_ctx(num=2) as sync_files:
world_size = 2
mp.spawn(
_layer_memory_tracking_ddp_worker, (sync_files, world_size), nprocs=world_size,
_layer_memory_tracking_ddp_worker,
(sync_files, world_size),
nprocs=world_size,
)
......@@ -129,7 +131,13 @@ def _layer_memory_tracking_ddp_worker(gpu_id: int, sync_files: Tuple[str, str],
# Create a simple model
torch.manual_seed(0)
torch.cuda.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10),)
model = nn.Sequential(
nn.Linear(10, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 10),
)
model = model.cuda(gpu_id)
ddp_model = DistributedDataParallel(model, device_ids=[gpu_id])
......@@ -156,7 +164,9 @@ def test_memory_tracking_fsdp():
with temp_files_ctx(num=2) as sync_files:
world_size = 2
mp.spawn(
_layer_memory_tracking_fsdp_worker, (sync_files, world_size), nprocs=world_size,
_layer_memory_tracking_fsdp_worker,
(sync_files, world_size),
nprocs=world_size,
)
......@@ -181,9 +191,17 @@ def _layer_memory_tracking_fsdp_worker(gpu_id: int, sync_files: Tuple[str, str],
model = nn.Sequential(
nn.Linear(10, 10).cuda(gpu_id),
nn.ReLU(),
FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=False, process_group=group,),
FullyShardedDataParallel(
nn.Linear(10, 10).cuda(gpu_id),
flatten_parameters=False,
process_group=group,
),
nn.ReLU(),
FullyShardedDataParallel(nn.Linear(10, 10).cuda(gpu_id), flatten_parameters=True, process_group=group,),
FullyShardedDataParallel(
nn.Linear(10, 10).cuda(gpu_id),
flatten_parameters=True,
process_group=group,
),
)
model = model.cuda(gpu_id)
dist_model = FullyShardedDataParallel(model, flatten_parameters=False, process_group=group)
......