Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
195d62f1
Unverified
Commit
195d62f1
authored
Mar 19, 2021
by
msbaines
Committed by
GitHub
Mar 19, 2021
Browse files
[test] use workaround to enable rpc tests when cuda not available (#541)
parent
84e0de84
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
42 deletions
+25
-42
fairscale/utils/testing.py
fairscale/utils/testing.py
+6
-1
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
...s/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
+1
-9
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+18
-25
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+0
-7
No files found.
fairscale/utils/testing.py
View file @
195d62f1
...
@@ -149,7 +149,12 @@ def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "")
...
@@ -149,7 +149,12 @@ def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "")
tp_options
=
{
"init_method"
:
url_rpc
}
tp_options
=
{
"init_method"
:
url_rpc
}
# Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
# Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
if
torch_version
()
==
(
1
,
8
,
0
):
if
torch_version
()
==
(
1
,
8
,
0
):
tp_options
[
"_transports"
]
=
[
"uv"
]
# type: ignore
if
torch
.
cuda
.
is_available
():
# Workaround for https://github.com/pytorch/pytorch/issues/53844
tp_options
[
"_transports"
]
=
[
"ibv"
,
"uv"
]
# type: ignore
else
:
# Workaround for https://github.com/pytorch/pytorch/issues/54266
tp_options
[
"_channels"
]
=
[
"mpt_uv"
,
"basic"
,
"cuda_ipc"
,
"cuda_gdr"
,
"cuda_xth"
,
"cuda_basic"
]
# type: ignore
rpc
.
init_rpc
(
rpc
.
init_rpc
(
f
"Test
{
rank
}
"
,
f
"Test
{
rank
}
"
,
...
...
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
View file @
195d62f1
...
@@ -17,21 +17,13 @@
...
@@ -17,21 +17,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.optim.optimizer
import
Optimizer
from
torch.optim.optimizer
import
Optimizer
from
torch.utils.data
import
DataLoader
,
Dataset
from
torch.utils.data
import
DataLoader
,
Dataset
from
fairscale.experimental.nn.ampnet_pipe.pipe
import
AMPnetPipe
from
fairscale.experimental.nn.ampnet_pipe.pipe
import
AMPnetPipe
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
,
torch_version
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
# Currently on CI, there appears to be a bug with torch 1.8
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if
torch_version
()
>=
(
1
,
8
,
0
):
pytestmark
=
pytest
.
mark
.
skip
class
MySGD
(
Optimizer
):
class
MySGD
(
Optimizer
):
...
...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
195d62f1
...
@@ -30,36 +30,29 @@ if torch.cuda.is_available():
...
@@ -30,36 +30,29 @@ if torch.cuda.is_available():
else
:
else
:
DEVICES
=
[
CPU_DEVICES
]
DEVICES
=
[
CPU_DEVICES
]
# The cuda requirement below is a workaround for https://github.com/pytorch/pytorch/issues/54266
pytestmark
=
pytest
.
mark
.
skipif
(
torch_version
()
<
(
1
,
8
,
0
),
reason
=
"requires torch version >= 1.8.0"
)
pytestmark
=
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
torch_version
()
<
(
1
,
8
,
0
),
reason
=
"requires torch version >= 1.8.0 and cuda"
)
def
rpc_worker
(
rank
,
world_size
,
init_file
,
func
,
*
args
):
def
rpc_worker
(
rank
,
world_size
,
init_file
,
func
,
*
args
):
# Workaround for https://github.com/pytorch/pytorch/issues/54266
if
torch_version
()
==
(
1
,
8
,
0
):
if
not
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
options
=
rpc
.
ProcessGroupRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
# Workaround for https://github.com/pytorch/pytorch/issues/53844
rpc
.
init_rpc
(
"worker"
+
str
(
rank
),
rank
=
rank
,
world_size
=
world_size
,
backend
=
rpc
.
BackendType
.
PROCESS_GROUP
,
rpc_backend_options
=
options
,
)
else
:
# Workaround for https://github.com/pytorch/pytorch/issues/53844
if
torch_version
()
==
(
1
,
8
,
0
):
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
,
_transports
=
[
"ibv"
,
"uv"
])
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
,
_transports
=
[
"ibv"
,
"uv"
])
else
:
else
:
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
# Workaround for https://github.com/pytorch/pytorch/issues/54266
rpc
.
init_rpc
(
options
=
rpc
.
TensorPipeRpcBackendOptions
(
"worker"
+
str
(
rank
),
init_method
=
"file://"
+
init_file
,
rank
=
rank
,
_channels
=
[
"mpt_uv"
,
"basic"
,
"cuda_ipc"
,
"cuda_gdr"
,
"cuda_xth"
,
"cuda_basic"
],
world_size
=
world_size
,
)
backend
=
rpc
.
BackendType
.
TENSORPIPE
,
else
:
rpc_backend_options
=
options
,
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
)
rpc
.
init_rpc
(
"worker"
+
str
(
rank
),
rank
=
rank
,
world_size
=
world_size
,
backend
=
rpc
.
BackendType
.
TENSORPIPE
,
rpc_backend_options
=
options
,
)
if
rank
==
0
:
if
rank
==
0
:
func
(
*
args
)
func
(
*
args
)
rpc
.
shutdown
()
rpc
.
shutdown
()
...
...
tests/nn/pipe_process/test_pipe.py
View file @
195d62f1
...
@@ -34,13 +34,6 @@ from fairscale.nn.model_parallel.initialize import (
...
@@ -34,13 +34,6 @@ from fairscale.nn.model_parallel.initialize import (
from
fairscale.nn.pipe
import
AsyncPipe
,
LazyModule
,
MultiProcessPipe
from
fairscale.nn.pipe
import
AsyncPipe
,
LazyModule
,
MultiProcessPipe
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
,
torch_version
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
,
torch_version
# Currently on CI, there appears to be a bug with torch 1.8
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if
torch_version
()
>=
(
1
,
8
,
0
):
pytestmark
=
pytest
.
mark
.
skip
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment