Unverified Commit 195d62f1 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[test] use workaround to enable rpc tests when cuda not available (#541)

parent 84e0de84
......@@ -149,7 +149,12 @@ def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "")
tp_options = {"init_method": url_rpc}
# Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
if torch_version() == (1, 8, 0):
tp_options["_transports"] = ["uv"] # type: ignore
if torch.cuda.is_available():
# Workaround for https://github.com/pytorch/pytorch/issues/53844
tp_options["_transports"] = ["ibv", "uv"] # type: ignore
else:
# Workaround for https://github.com/pytorch/pytorch/issues/54266
tp_options["_channels"] = ["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"] # type: ignore
rpc.init_rpc(
f"Test{rank}",
......
......@@ -17,21 +17,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
# Currently, on CI, there appears to be a bug with torch 1.8
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if torch_version() >= (1, 8, 0):
pytestmark = pytest.mark.skip
from fairscale.utils.testing import get_worker_map, torch_spawn
class MySGD(Optimizer):
......
......@@ -30,36 +30,29 @@ if torch.cuda.is_available():
else:
DEVICES = [CPU_DEVICES]
# cuda test is because of https://github.com/pytorch/pytorch/issues/54266
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or torch_version() < (1, 8, 0), reason="requires torch version >= 1.8.0 and cuda"
)
pytestmark = pytest.mark.skipif(torch_version() < (1, 8, 0), reason="requires torch version >= 1.8.0")
def rpc_worker(rank, world_size, init_file, func, *args):
    """Initialize a TensorPipe RPC worker, run ``func`` on rank 0, then shut down.

    Args:
        rank: this process's rank in the RPC group.
        world_size: total number of RPC workers.
        init_file: rendezvous file; workers coordinate via ``file://`` init.
        func: callable executed only on rank 0 once all workers have joined.
        *args: positional arguments forwarded to ``func``.
    """
    if torch_version() == (1, 8, 0):
        if torch.cuda.is_available():
            # Workaround for https://github.com/pytorch/pytorch/issues/53844
            options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file, _transports=["ibv", "uv"])
        else:
            # Workaround for https://github.com/pytorch/pytorch/issues/54266
            options = rpc.TensorPipeRpcBackendOptions(
                init_method="file://" + init_file,
                _channels=["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"],
            )
    else:
        options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
    # Only rank 0 drives the test body; other ranks just serve RPCs until shutdown.
    if rank == 0:
        func(*args)
    rpc.shutdown()
......
......@@ -34,13 +34,6 @@ from fairscale.nn.model_parallel.initialize import (
from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
# Currently, on CI, there appears to be a bug with torch 1.8
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if torch_version() >= (1, 8, 0):
pytestmark = pytest.mark.skip
@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment