Unverified Commit 2d954203 authored by Benjamin Lefaudeux, committed by GitHub

[chore][ci] restore 1.5 & 1.6 tests and compatibility (#306)

* Tentatively fix the CPU version of the CircleCI jobs; the pipe tests are now the last ones standing
* Fix OSS backward compatibility, and try to fix RPC on older PyTorch as well
* Fix the file-based init in torch 1.5
parent 6219b57b
@@ -43,8 +43,8 @@ install_dep_15: &install_dep_15
       name: Install Dependencies
       command: |
         sudo apt-get install -y libopenmpi-dev
-        pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
+        pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         python -c 'import torch; print("Torch version:", torch.__version__)'
         python -m torch.utils.collect_env
@@ -53,37 +53,34 @@ install_dep_16: &install_dep_16
       name: Install Dependencies
       command: |
         sudo apt-get install -y libopenmpi-dev
-        pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
+        pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
         python -c 'import torch; print("Torch version:", torch.__version__)'
         python -m torch.utils.collect_env

-install_dep_17_cpu: &install_dep_17_cpu
+install_dep_17: &install_dep_17
   - run:
       name: Install Dependencies
       command: |
         sudo apt-get install -y libopenmpi-dev
-        pip install --progress-bar off torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
+        pip install --progress-bar off torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
         python -c 'import torch; print("Torch version:", torch.__version__)'
         python -m torch.utils.collect_env

-install_dep_17_gpu: &install_dep_17_gpu
-  # FIXME: need to be removed when properly handling torch 1.7.1
-  # short term fix is to override the default pip installed torch
+install_dep_17_cpu: &install_dep_17_cpu
   - run:
       name: Install Dependencies
       command: |
         sudo apt-get install -y libopenmpi-dev
         pip install --progress-bar off -r requirements-test.txt
-        pip install --progress-bar off torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
         python -c 'import torch; print("Torch version:", torch.__version__)'
         python -m torch.utils.collect_env

 install_repo_cpu: &install_repo_cpu
   - run:
       name: Install Repository
@@ -291,7 +288,7 @@ jobs:
           keys:
             - cache-key-gpu17-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-      - <<: *install_dep_17_gpu
+      - <<: *install_dep_17
       - save_cache:
           paths:
@@ -328,7 +325,7 @@ jobs:
           keys:
            - cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-      - <<: *install_dep_17_gpu
+      - <<: *install_dep_17
       - save_cache:
           paths:
......
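The ordering change is the core of the CI fix: requirements-test.txt can drag in a newer default torch build (presumably 1.7.1, per the removed FIXME), so the pinned +cu101 wheel is installed last to override it. A hypothetical post-install sanity check in the spirit of the existing `python -c` step, with the expected version as an assumed parameter:

# Hypothetical check, not part of the diff: fail fast if the pinned wheel
# was overridden by a later "pip install -r requirements-test.txt".
import torch

def check_pinned_torch(expected: str) -> None:
    # torch.__version__ looks like "1.7.0+cu101" for the pinned CUDA wheel
    if not torch.__version__.startswith(expected):
        raise RuntimeError(f"expected torch {expected}*, got {torch.__version__}")

check_pinned_torch("1.7.0")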
@@ -31,7 +31,6 @@ repos:
     rev: v2.1.0
     hooks:
       - id: seed-isort-config
-        language_version: python3.6
   - repo: https://github.com/pycqa/isort
     rev: 5.6.4
......
@@ -4,11 +4,11 @@
 # LICENSE file in the root directory of this source tree.

 from collections import OrderedDict
+from dataclasses import dataclass
 from enum import Enum, auto
 from threading import Event
 from typing import Dict, Iterable, List, Optional, Tuple
-from dataclasses import dataclass

 import torch
 from torch import Tensor, nn
 from torch.distributed import ProcessGroup
......
@@ -4,11 +4,11 @@
 # LICENSE file in the root directory of this source tree.

 from abc import ABC
+from dataclasses import dataclass
 from queue import Empty as QueueEmpty
 from queue import Queue
 from typing import Dict, List, Optional
-from dataclasses import dataclass

 import torch

 from fairscale.nn.model_parallel import get_pipeline_parallel_group
......
@@ -19,12 +19,12 @@
 """The Pipe interface."""

 from collections import OrderedDict
+from dataclasses import dataclass, field
 import itertools
 import threading
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
 import warnings
-from dataclasses import dataclass, field

 import torch
 from torch import Tensor, nn
 import torch.autograd
......
@@ -3,10 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from dataclasses import dataclass
 from enum import Enum, auto
 from typing import Any, Callable, List, Optional, Tuple, Union
-from dataclasses import dataclass

 import torch
 from torch import Tensor, nn
......
@@ -554,7 +554,6 @@ class AdaScale(Optimizer):
         `set_scale` needs to be called to update the scale as well.

         TODO (min): need a way of determine how much to increase the step size?

         TODO (min): have both `set_scale` and `set_num_gradients_to_accumulate`
         is hard to use and easy to make mistake. I think it is better
         to specific a specify a `base_scale`. But more discussion is
......
@@ -25,6 +25,15 @@ if TYPE_CHECKING:  # pragma: no cover
 else:
     _params_t = Any

+try:
+    from torch.distributed import broadcast_object_list  # noqa
+
+    _torch_broadcast_object = True
+except ImportError:
+    from .utils import broadcast_object
+
+    _torch_broadcast_object = False
+

 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
@@ -339,12 +348,27 @@ class OSS(Optimizer):
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
+                if _torch_broadcast_object:
+                    # torch native object broadcast
                     dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
+                else:
+                    # legacy compatibility for old torch versions
+                    broadcast_object(
+                        self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
+                    )
             else:
                 global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
+                if _torch_broadcast_object:
                     dist.broadcast_object_list([0], src=global_rank, group=self.group)
+                else:
+                    broadcast_object(
+                        torch.tensor([0], dtype=torch.uint8, device=self._device),
+                        src_rank=global_rank,
+                        group=self.group,
+                        dist_device=self._device,
+                    )

     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
         """Collect all the state shards, in CPU memory."""
@@ -358,16 +382,35 @@ class OSS(Optimizer):
                 )

                 # Sync with other replicas
+                if _torch_broadcast_object:
+                    # torch native object broadcast
                     dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
+                else:
+                    # legacy compatibility for old torch versions
+                    broadcast_object(
+                        torch.tensor([0], dtype=torch.uint8, device=self._device),
+                        src_rank=self.global_rank,
+                        group=self.group,
+                        dist_device=self._device,
+                    )
             else:
                 # Fetch the optim state from the other replicas
                 global_rank = self.get_global_rank(self.group, rank)
-                replica_state = [0]
-                dist.broadcast_object_list(replica_state, src=global_rank, group=self.group)
+                if _torch_broadcast_object:
+                    replica_state_l = [0]
+                    dist.broadcast_object_list(replica_state_l, src=global_rank, group=self.group)
+                    replica_state = replica_state_l[0]
+                else:
+                    replica_state = broadcast_object(
+                        torch.tensor([0], dtype=torch.uint8, device=self._device),
+                        src_rank=global_rank,
+                        group=self.group,
+                        dist_device=self._device,
+                    )

                 all_states.append(
-                    recursive_copy_to_device(replica_state[0], non_blocking=True, device=torch.device("cpu"))
+                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                 )

                 logging.debug("State from rank %s received", rank)
......
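For reference, the feature detection above can be folded into a single helper so call sites avoid repeating the if/else. A minimal sketch under the diff's assumptions (torch >= 1.7 ships broadcast_object_list, older versions fall back to the fairscale.optim.utils.broadcast_object added below); the bcast_obj name is hypothetical, not part of this change:

from typing import Any

import torch
import torch.distributed as dist

try:
    from torch.distributed import broadcast_object_list  # noqa  # torch >= 1.7
    _torch_broadcast_object = True
except ImportError:
    _torch_broadcast_object = False

def bcast_obj(obj: Any, src_rank: int, group: Any, device: torch.device) -> Any:
    """Broadcast a picklable object from src_rank and return the received copy."""
    if _torch_broadcast_object:
        box = [obj]  # broadcast_object_list mutates the list in place
        dist.broadcast_object_list(box, src=src_rank, group=group)
        return box[0]
    # legacy path: length-prefixed torch.save/torch.load over two broadcasts
    from fairscale.optim.utils import broadcast_object
    return broadcast_object(obj, src_rank=src_rank, group=group, dist_device=device)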
@@ -3,10 +3,12 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+import io
 from typing import Any, Callable, Dict, Optional

 import torch
 from torch._six import container_abcs
+import torch.distributed as dist


 class Workhandle:
@@ -44,3 +46,33 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
         return device_val

     return value
+
+
+# backward compatibility - this is needed for torch 1.5 which does not expose this functionality
+# FIXME: to be dropped alongside torch1.5 support, when time comes
+def broadcast_object(
+    obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
+) -> Any:
+    """
+    Either broadcast from master to the fleet (default),
+    or use the src setting as the original rank.
+    """
+    if dist.get_rank() == src_rank:
+        # Emit data
+        buffer = io.BytesIO()
+        torch.save(obj, buffer)
+        data = bytearray(buffer.getbuffer())
+        length_tensor = torch.LongTensor([len(data)]).to(dist_device)
+        data_send_tensor = torch.ByteTensor(data).to(dist_device)
+        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
+    else:
+        # Fetch from the source
+        length_tensor = torch.LongTensor([0]).to(dist_device)
+        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
+        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
+        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
+        obj = torch.load(buffer, map_location=dist_device)
+    return obj
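The helper's wire format is easy to see without a process group: torch.save the object into a buffer, send the length first so receivers can size their tensor, then send the payload bytes. A single-process sketch of that round trip, with the two dist.broadcast calls replaced by plain tensor handoffs:

import io
import torch

obj = {"step": 3, "lr": 0.1}

# sender side: serialize, then announce the length before the payload
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)])  # would be broadcast first
payload = torch.ByteTensor(data)               # would be broadcast second

# receiver side: size the buffer from the announced length, then deserialize
received = torch.empty([int(length_tensor.item())], dtype=torch.uint8)
received.copy_(payload)
assert torch.load(io.BytesIO(received.numpy())) == obj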
@@ -100,21 +100,26 @@ def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "")
     .. warning: This limits the usecase to all ranks being on the same node
     """

+    try:
+        torch.distributed.rpc.shutdown()
+    except Exception:
+        pass
+
     print(f"dist init r={rank}, world={world_size}")

     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["RANK"] = str(rank)
     url = "file://" + filename
+    url_rpc = "file://" + filename_rpc

     if torch_version() >= (1, 6, 0):
         backend = "nccl" if torch.cuda.is_available() else "gloo"
         if backend == "nccl" and torch.cuda.device_count() < world_size:
             logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
             return False

         torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)
-        url_rpc = "file://" + filename_rpc

         rpc.init_rpc(
             f"Test{rank}",
             rank=rank,
@@ -125,7 +130,13 @@ def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "")
     else:
         if world_size > 1:
-            rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
+            # TensorPipe is not available in Torch 1.5
+            rpc.init_rpc(
+                name=f"Test{rank}",
+                rank=rank,
+                world_size=world_size,
+                rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=url_rpc),
+            )
         elif torch.cuda.is_available():
             torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
         else:
@@ -153,7 +164,7 @@ def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_
         _, filename_rpc = tempfile.mkstemp()

         # (lefaudeux) Let mp handle the process joining, join=False and handling context has been unstable in the past
-        mp.spawn(test_func, args=(world_size, filename, filename_rpc, *args), nprocs=world_size, join=True)  # type: ignore
+        mp.spawn(test_func, args=(world_size, filename, filename_rpc, *args), nprocs=world_size, join=True)


 def worker_process(
@@ -163,6 +174,7 @@ def worker_process(
     if not dist_init(rank, world_size, filename, filename_rpc):
         logging.warning("failed initializing torch distributed")
+        teardown()
         return

     kwargs = {}
@@ -195,7 +207,8 @@ def teardown() -> None:
     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()

     try:
-        torch.distributed.rpc.shutdown()
+        # torch 1.5 hangs on shutdown if waiting for all processes
+        torch.distributed.rpc.shutdown(graceful=False)
     except Exception:
         pass
@@ -230,10 +243,10 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
             if "OMPI_COMM_WORLD_RANK" in os.environ:
                 os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
                 os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
-                os.environ["MASTER_ADDR"] = "localhost"
-                os.environ["MASTER_PORT"] = "10638"
-                torch.distributed.init_process_group("mpi")
+                _, filename = tempfile.mkstemp()
+                torch.distributed.init_process_group("mpi", init_method=f"file://{filename}")

                 world_size = torch.distributed.get_world_size()
+                destroy_model_parallel()
                 initialize_model_parallel(1, world_size)
                 torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

             if world_size in world_sizes:
......
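The torch_spawn change above swaps the MASTER_ADDR/MASTER_PORT environment rendezvous for a file:// init_method, which sidesteps port collisions when many test processes launch concurrently. A single-process sketch of that rendezvous (gloo so it runs on CPU-only machines):

import tempfile

import torch.distributed as dist

_, filename = tempfile.mkstemp()
dist.init_process_group(
    backend="gloo",                    # CPU-friendly backend
    init_method=f"file://{filename}",  # all ranks rendezvous on the same file
    rank=0,
    world_size=1,
)
assert dist.is_initialized()
dist.destroy_process_group()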
@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]

 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
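Dropping "dataclasses" from known_third_party is what drives the import reshuffles in the Python files above: isort now classifies dataclasses as standard library, so with force_sort_within_sections it lands alphabetically in the stdlib block instead of trailing it. For instance, the pipe module headers now sort as:

# stdlib block: dataclasses sorts between collections and enum
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto

# third-party block stays separate
import torch
from torch import Tensor, nn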
@@ -3,40 +3,34 @@
 from typing import Union, Callable, Optional, Any
 from torch.futures import Future

-class RRef:
-    ...
-
-class WorkerInfo:
-    ...
+class RRef: ...
+class WorkerInfo: ...

 class BackendType:
     TENSORPIPE: Any
     PROCESS_GROUP: Any

-def TensorPipeRpcBackendOptions(init_method: str) -> Any : ...
+def TensorPipeRpcBackendOptions(init_method: str) -> Any: ...
+def ProcessGroupRpcBackendOptions(init_method: str) -> Any: ...

 def rpc_async(
     to: Union[str, WorkerInfo],
     func: Callable,
     args: Optional[tuple] = None,
     kwargs: Optional[dict] = None,
     timeout=-1.0,
-) -> Future:
-    ...
+) -> Future: ...

 def rpc_sync(
     to: Union[str, WorkerInfo],
     func: Callable,
     args: Optional[tuple] = None,
     kwargs: Optional[dict] = None,
     timeout=-1.0,
-) -> None:
-    ...
+) -> None: ...

-def init_rpc(name: str, backend: Optional[Any] = None, rank:int = -1, world_size: Optional[int] = None, rpc_backend_options: Optional[Any] = None) -> None: ...
+def init_rpc(
+    name: str,
+    backend: Optional[Any] = None,
+    rank: int = -1,
+    world_size: Optional[int] = None,
+    rpc_backend_options: Optional[Any] = None,
+) -> None: ...

-def shutdown() -> None: ...
+def shutdown(graceful: Optional[bool] = True) -> None: ...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, Callable, Optional, Tuple
from torch import Tensor
def spawn(
fn: Callable[[Any], Any],
args: Tuple[Optional[Any], ...] = (),
nprocs: int = 1,
join: bool = True,
daemon: bool = False,
start_method: str = "spawn",
): ...
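The new torch.multiprocessing stub matches the real spawn signature, where the worker function receives its process index first and the args tuple after. A tiny runnable sketch:

import torch.multiprocessing as mp

def _worker(rank: int, greeting: str) -> None:
    # spawn passes the process index as the first argument
    print(f"{greeting} from rank {rank}")

if __name__ == "__main__":
    mp.spawn(_worker, args=("hello",), nprocs=2, join=True)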
@@ -20,9 +20,11 @@ if torch.cuda.is_available():
     devices = ["cpu", "cuda"]
 else:
     devices = ["cpu"]

 URL = "file://" + tempfile.mkstemp()[1]
+os.environ["MASTER_ADDR"] = "localhost"
+os.environ["MASTER_PORT"] = "29501"  # torch 1.5 compatibility

 if "OMPI_COMM_WORLD_SIZE" in os.environ:
     dist.init_process_group(backend=dist.Backend.MPI, init_method=URL)
......
@@ -27,6 +27,15 @@ from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu
 BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
 DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")

+try:
+    from torch.distributed import broadcast_object_list  # noqa
+
+    _torch_broadcast_object = True
+except ImportError:
+    from fairscale.optim.utils import broadcast_object  # noqa
+
+    _torch_broadcast_object = False
+

 def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
     url = "file://" + tempfile_name
@@ -394,10 +403,16 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
         optimizer_state_dict = {}

     optim_state = [optimizer_state_dict]
+    if _torch_broadcast_object:
         dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)
+        optimizer_state_dict = optim_state[0]
+    else:
+        optimizer_state_dict = optim.utils.broadcast_object(
+            optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
+        )

     # Load the optimizer state dict
-    optimizer.load_state_dict(optim_state[0])
+    optimizer.load_state_dict(optimizer_state_dict)

     dist.destroy_process_group()
......
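The test change hinges on broadcast_object_list mutating the passed-in list rather than returning a value, hence reading optim_state[0] back after the call. A single-rank illustration (torch >= 1.7, gloo):

import tempfile

import torch.distributed as dist

_, filename = tempfile.mkstemp()
dist.init_process_group("gloo", init_method=f"file://{filename}", rank=0, world_size=1)

optim_state = [{"lr": 0.1}]
dist.broadcast_object_list(optim_state, src=0)  # overwrites the list entries in place
optimizer_state_dict = optim_state[0]           # read the received object back out
assert optimizer_state_dict == {"lr": 0.1}

dist.destroy_process_group()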