Unverified commit 2d954203, authored by Benjamin Lefaudeux and committed by GitHub

[chore][ci] restore 1.5 & 1.6 tests and compatibility (#306)

* tentatively fix the CPU version of the CircleCI jobs; the pipe tests are now the last ones standing
* fix the OSS backward compatibility, and try to fix RPC on older PyTorch as well
* fix the file-based init in torch 1.5
parent 6219b57b
......@@ -43,8 +43,8 @@ install_dep_15: &install_dep_15
name: Install Dependencies
command: |
sudo apt-get install -y libopenmpi-dev
pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
......@@ -53,37 +53,34 @@ install_dep_16: &install_dep_16
name: Install Dependencies
command: |
sudo apt-get install -y libopenmpi-dev
pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
install_dep_17_cpu: &install_dep_17_cpu
install_dep_17: &install_dep_17
- run:
name: Install Dependencies
command: |
sudo apt-get install -y libopenmpi-dev
pip install --progress-bar off torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
install_dep_17_gpu: &install_dep_17_gpu
# FIXME: needs to be removed once torch 1.7.1 is handled properly
# the short-term fix is to override the torch version that pip installed by default
install_dep_17_cpu: &install_dep_17_cpu
- run:
name: Install Dependencies
command: |
sudo apt-get install -y libopenmpi-dev
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
install_repo_cpu: &install_repo_cpu
- run:
name: Install Repository
......@@ -291,7 +288,7 @@ jobs:
keys:
- cache-key-gpu17-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_17_gpu
- <<: *install_dep_17
- save_cache:
paths:
......@@ -328,7 +325,7 @@ jobs:
keys:
- cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_17_gpu
- <<: *install_dep_17
- save_cache:
paths:
......
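The reordering above installs `requirements-test.txt` first and pins the torch wheel afterwards, so the explicit `+cu101` build overrides whatever torch the requirements pull in transitively. A sanity check along these lines could be appended to the install step, for instance for the 1.7 job; it is illustrative only and not part of the committed config:

```python
import torch

# Illustrative check, not in the CI config: after the override, the explicitly
# pinned build should be the one that ends up importable.
print("Torch version:", torch.__version__)
print("CUDA build:", torch.version.cuda)
assert torch.__version__.startswith("1.7.0"), "pinned torch was overridden by requirements-test.txt"
```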
......@@ -31,7 +31,6 @@ repos:
rev: v2.1.0
hooks:
- id: seed-isort-config
language_version: python3.6
- repo: https://github.com/pycqa/isort
rev: 5.6.4
......
......@@ -4,11 +4,11 @@
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto
from threading import Event
from typing import Dict, Iterable, List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from torch.distributed import ProcessGroup
......
......@@ -4,11 +4,11 @@
# LICENSE file in the root directory of this source tree.
from abc import ABC
from dataclasses import dataclass
from queue import Empty as QueueEmpty
from queue import Queue
from typing import Dict, List, Optional
from dataclasses import dataclass
import torch
from fairscale.nn.model_parallel import get_pipeline_parallel_group
......
......@@ -19,12 +19,12 @@
"""The Pipe interface."""
from collections import OrderedDict
from dataclasses import dataclass, field
import itertools
import threading
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
from dataclasses import dataclass, field
import torch
from torch import Tensor, nn
import torch.autograd
......
......@@ -3,10 +3,10 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, List, Optional, Tuple, Union
from dataclasses import dataclass
import torch
from torch import Tensor, nn
......
......@@ -554,11 +554,10 @@ class AdaScale(Optimizer):
`set_scale` needs to be called to update the scale as well.
TODO (min): need a way of determining how much to increase the step size?
TODO (min): having both `set_scale` and `set_num_gradients_to_accumulate`
    is hard to use and makes mistakes easy. I think it is better
    to specify a `base_scale`. But more discussion is needed here.
Args:
num_gradients_to_accumulate (int):
......
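For context on the TODO above, a hypothetical usage sketch of how the two calls interact when switching an AdaScale-wrapped optimizer to gradient accumulation; the model, learning rate, world size and accumulation factor below are made up for illustration and this is not code from the repository:

```python
import torch
from fairscale.optim import AdaScale

world_size = 2        # assumed number of data-parallel workers
accumulate_steps = 4  # assumed accumulation factor

model = torch.nn.Linear(4, 2)
optim = AdaScale(torch.optim.SGD(model.parameters(), lr=0.1), world_size=world_size)

# Per the docstring above, the scale has to be kept in sync by hand whenever the
# accumulation factor changes, which is the usability issue the TODO points at.
optim.set_num_gradients_to_accumulate(accumulate_steps)
optim.set_scale(world_size * accumulate_steps)  # scale tracks the effective batch-size growth
```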
......@@ -25,6 +25,15 @@ if TYPE_CHECKING: # pragma: no cover
else:
_params_t = Any
try:
from torch.distributed import broadcast_object_list # noqa
_torch_broadcast_object = True
except ImportError:
from .utils import broadcast_object
_torch_broadcast_object = False
class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
......@@ -339,12 +348,27 @@ class OSS(Optimizer):
logging.debug(
"Sending the sharded optimizer state to the reference replica from rank %s", rank,
)
dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
if _torch_broadcast_object:
# torch native object broadcast
dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
else:
# legacy compatibility for old torch versions
broadcast_object(
self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
)
else:
global_rank = self.get_global_rank(self.group, rank)
# Discard this tensor/rank; the broadcast is still necessary for syncing, and because NCCL does not support gather
dist.broadcast_object_list([0], src=global_rank, group=self.group)
if _torch_broadcast_object:
dist.broadcast_object_list([0], src=global_rank, group=self.group)
else:
broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device),
src_rank=global_rank,
group=self.group,
dist_device=self._device,
)
def _collect_sharded_states(self) -> List[Dict[str, Any]]:
"""Collect all the state shards, in CPU memory."""
......@@ -358,16 +382,35 @@ class OSS(Optimizer):
)
# Sync with other replicas
dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
if _torch_broadcast_object:
# torch native object broadcast
dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
else:
# legacy compatibility for old torch versions
broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device),
src_rank=self.global_rank,
group=self.group,
dist_device=self._device,
)
else:
# Fetch the optim state from the other replicas
global_rank = self.get_global_rank(self.group, rank)
replica_state = [0]
dist.broadcast_object_list(replica_state, src=global_rank, group=self.group)
if _torch_broadcast_object:
replica_state_l = [0]
dist.broadcast_object_list(replica_state_l, src=global_rank, group=self.group)
replica_state = replica_state_l[0]
else:
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device),
src_rank=global_rank,
group=self.group,
dist_device=self._device,
)
all_states.append(
recursive_copy_to_device(replica_state[0], non_blocking=True, device=torch.device("cpu"))
recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
)
logging.debug("State from rank %s received", rank)
......
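The change above is plain feature detection: when `torch.distributed.broadcast_object_list` exists (torch >= 1.7) it is used directly, otherwise the state goes through the byte-tensor based `broadcast_object` helper added in `fairscale/optim/utils.py` below. A condensed sketch of the same dispatch, where `sync_object` and `payload` are hypothetical names and an already-initialized process group is assumed:

```python
import torch
import torch.distributed as dist

try:
    from torch.distributed import broadcast_object_list  # torch >= 1.7
    _torch_broadcast_object = True
except ImportError:
    from fairscale.optim.utils import broadcast_object  # torch 1.5 / 1.6 fallback
    _torch_broadcast_object = False


def sync_object(payload, src_rank, group=dist.group.WORLD, device=torch.device("cpu")):
    # Hypothetical wrapper: broadcast `payload` from src_rank and return it on every rank.
    if _torch_broadcast_object:
        box = [payload if dist.get_rank() == src_rank else None]
        dist.broadcast_object_list(box, src=src_rank, group=group)
        return box[0]
    return broadcast_object(payload, src_rank=src_rank, group=group, dist_device=device)
```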
......@@ -3,10 +3,12 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import io
from typing import Any, Callable, Dict, Optional
import torch
from torch._six import container_abcs
import torch.distributed as dist
class Workhandle:
......@@ -44,3 +46,33 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return device_val
return value
# backward compatibility - this is needed for torch 1.5, which does not expose this functionality
# FIXME: to be dropped alongside torch 1.5 support, when the time comes
def broadcast_object(
obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
"""
Broadcast an arbitrary picklable object from ``src_rank`` to all the other ranks in the
group, serializing it through a byte tensor since older torch has no native object broadcast.
"""
if dist.get_rank() == src_rank:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(dist_device)
data_send_tensor = torch.ByteTensor(data).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
else:
# Fetch from the source
length_tensor = torch.LongTensor([0]).to(dist_device)
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device)
return obj
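The helper saves the object with `torch.save`, broadcasts its byte length first and then the raw bytes, and reloads it on the receiving ranks. A minimal usage sketch, assuming a two-rank gloo process group is already initialized; the payload dictionary is made up for illustration:

```python
import torch
import torch.distributed as dist
from fairscale.optim.utils import broadcast_object

# Assumes dist.init_process_group("gloo", ...) has already run on every rank,
# e.g. via the dist_init helper from fairscale.utils.testing.
payload = {"step": 3, "lr": 0.1} if dist.get_rank() == 0 else None
payload = broadcast_object(payload, src_rank=0, dist_device=torch.device("cpu"))
# every rank now holds {"step": 3, "lr": 0.1}
```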
......@@ -100,21 +100,26 @@ def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "")
.. warning:: This limits the use case to all ranks being on the same node
"""
try:
torch.distributed.rpc.shutdown()
except Exception:
pass
print(f"dist init r={rank}, world={world_size}")
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(rank)
url = "file://" + filename
url_rpc = "file://" + filename_rpc
if torch_version() >= (1, 6, 0):
backend = "nccl" if torch.cuda.is_available() else "gloo"
if backend == "nccl" and torch.cuda.device_count() < world_size:
logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
return False
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)
url_rpc = "file://" + filename_rpc
rpc.init_rpc(
f"Test{rank}",
rank=rank,
......@@ -125,7 +130,13 @@ def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "")
else:
if world_size > 1:
rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
# TensorPipe is not available in Torch 1.5
rpc.init_rpc(
name=f"Test{rank}",
rank=rank,
world_size=world_size,
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=url_rpc),
)
elif torch.cuda.is_available():
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
else:
......@@ -153,7 +164,7 @@ def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_
_, filename_rpc = tempfile.mkstemp()
# (lefaudeux) Let mp handle the process joining; join=False and handling the context ourselves has been unstable in the past
mp.spawn(test_func, args=(world_size, filename, filename_rpc, *args), nprocs=world_size, join=True) # type: ignore
mp.spawn(test_func, args=(world_size, filename, filename_rpc, *args), nprocs=world_size, join=True)
def worker_process(
......@@ -163,6 +174,7 @@ def worker_process(
if not dist_init(rank, world_size, filename, filename_rpc):
logging.warning("failed initializing torch distributed")
teardown()
return
kwargs = {}
......@@ -195,7 +207,8 @@ def teardown() -> None:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
try:
torch.distributed.rpc.shutdown()
# torch 1.5 hangs on shutdown if waiting for all processes
torch.distributed.rpc.shutdown(graceful=False)
except Exception:
pass
......@@ -230,10 +243,10 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
if "OMPI_COMM_WORLD_RANK" in os.environ:
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10638"
torch.distributed.init_process_group("mpi")
_, filename = tempfile.mkstemp()
torch.distributed.init_process_group("mpi", init_method=f"file://{filename}")
world_size = torch.distributed.get_world_size()
destroy_model_parallel()
initialize_model_parallel(1, world_size)
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
if world_size in world_sizes:
......
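The gist of the `dist_init` changes above is to pick the RPC backend by torch version: TensorPipe where it exists (torch >= 1.6), and the ProcessGroup backend with an explicit file-based `init_method` on torch 1.5, which is also why `teardown` now shuts RPC down non-gracefully. A stripped-down sketch of that dispatch; `init_rpc_compat` is a hypothetical name, the version check is inlined, and the process-group setup handled elsewhere in `dist_init` is left out:

```python
import torch
import torch.distributed.rpc as rpc


def init_rpc_compat(rank: int, world_size: int, filename_rpc: str) -> None:
    # Hypothetical condensation of the version dispatch in dist_init above.
    url_rpc = "file://" + filename_rpc
    version = tuple(int(x) for x in torch.__version__.split(".")[:2])
    if version >= (1, 6):
        rpc.init_rpc(
            f"Test{rank}",
            rank=rank,
            world_size=world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=url_rpc),
        )
    else:
        # TensorPipe does not exist in torch 1.5; fall back to the ProcessGroup backend.
        rpc.init_rpc(
            name=f"Test{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=url_rpc),
        )
```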
......@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
known_third_party = ["benchmark_dataset", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
......@@ -3,40 +3,34 @@
from typing import Union, Callable, Optional, Any
from torch.futures import Future
class RRef:
...
class WorkerInfo:
...
class RRef: ...
class WorkerInfo: ...
class BackendType:
TENSORPIPE: Any
PROCESS_GROUP: Any
def TensorPipeRpcBackendOptions(init_method: str) -> Any : ...
def TensorPipeRpcBackendOptions(init_method: str) -> Any: ...
def ProcessGroupRpcBackendOptions(init_method: str) -> Any: ...
def rpc_async(
to: Union[str, WorkerInfo],
func: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
timeout=-1.0,
) -> Future:
...
) -> Future: ...
def rpc_sync(
to: Union[str, WorkerInfo],
func: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
timeout=-1.0,
) -> None:
...
def init_rpc(name: str, backend: Optional[Any] = None, rank:int = -1, world_size: Optional[int] = None, rpc_backend_options: Optional[Any] = None) -> None: ...
def shutdown() -> None: ...
) -> None: ...
def init_rpc(
name: str,
backend: Optional[Any] = None,
rank: int = -1,
world_size: Optional[int] = None,
rpc_backend_options: Optional[Any] = None,
) -> None: ...
def shutdown(graceful: Optional[bool] = True) -> None: ...
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, Callable, Optional, Tuple
from torch import Tensor
def spawn(
fn: Callable[[Any], Any],
args: Tuple[Optional[Any], ...] = (),
nprocs: int = 1,
join: bool = True,
daemon: bool = False,
start_method: str = "spawn",
): ...
......@@ -20,9 +20,11 @@ if torch.cuda.is_available():
devices = ["cpu", "cuda"]
else:
devices = ["cpu"]
URL = "file://" + tempfile.mkstemp()[1]
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501" # torch 1.5 compatibility
if "OMPI_COMM_WORLD_SIZE" in os.environ:
dist.init_process_group(backend=dist.Backend.MPI, init_method=URL)
......
......@@ -27,6 +27,15 @@ from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
try:
from torch.distributed import broadcast_object_list # noqa
_torch_broadcast_object = True
except ImportError:
from fairscale.optim.utils import broadcast_object # noqa
_torch_broadcast_object = False
def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
url = "file://" + tempfile_name
......@@ -394,10 +403,16 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
optimizer_state_dict = {}
optim_state = [optimizer_state_dict]
dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)
if _torch_broadcast_object:
dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)
optimizer_state_dict = optim_state[0]
else:
optimizer_state_dict = optim.utils.broadcast_object(
optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
)
# Load the optimizer state dict
optimizer.load_state_dict(optim_state[0])
optimizer.load_state_dict(optimizer_state_dict)
dist.destroy_process_group()
......