OpenDAS / fairscale / Commits / e83da060

Commit e83da060 (unverified), authored Dec 01, 2020 by Benjamin Lefaudeux, committed via GitHub on Dec 01, 2020

[chore] Refactor unit testing, shared utils (#218)
Parent: 1db8bbda

Showing 19 changed files with 117 additions and 70 deletions (+117 −70)
Changed files:

  benchmarks/pipe.py                              +1   −1
  fairscale/utils/testing.py                      +48  −17
  pyproject.toml                                  +1   −1
  stubs/torch/cuda/__init__.pyi                   +7   −0
  stubs/torch/distributed/__init__.pyi            +3   −1
  stubs/torch/distributed/rpc/__init__.pyi        +11  −1
  tests/nn/model_parallel/test_cross_entropy.py   +1   −1
  tests/nn/model_parallel/test_initialize.py      +1   −1
  tests/nn/model_parallel/test_layers.py          +1   −7
  tests/nn/model_parallel/test_random.py          +1   −1
  tests/nn/pipe/conftest.py                       +29  −7
  tests/nn/pipe_process/conftest.py               +6   −5
  tests/nn/pipe_process/skip/test_gpipe.py        +1   −1
  tests/nn/pipe_process/skip/test_leak.py         +1   −1
  tests/nn/pipe_process/test_bugs.py              +1   −1
  tests/nn/pipe_process/test_inplace.py           +1   −1
  tests/nn/pipe_process/test_pipe.py              +1   −21
  tests/nn/pipe_process/test_rpc.py               +1   −1
  tests/nn/pipe_process/test_transparency.py      +1   −1
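The thrust of the change: helpers that used to live in tests.nn.model_parallel.commons (dist_init, get_worker_map, spawn_for_all_world_sizes, torch_spawn, IdentityLayer, ...) now ship as fairscale.utils.testing, and every test imports them from there. A minimal sketch of a distributed test written against the relocated module (the test itself is hypothetical, not part of this commit):

import torch
import torch.distributed as dist

from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


def run_allreduce_check(rank: int, world_size: int) -> None:
    # dist_init sets up the process group (and RPC) for this rank, as shown in the diff below
    dist_init(rank, world_size)
    tensor = torch.ones(1)
    dist.all_reduce(tensor)
    assert tensor.item() == world_size


def test_allreduce() -> None:
    # spawns run_allreduce_check once per rank, for every world size that fits on the local GPUs
    spawn_for_all_world_sizes(run_allreduce_check)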
benchmarks/pipe.py

@@ -23,7 +23,7 @@ from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_
 from fairscale.nn.pipe import LazyModule, pipe
 from fairscale.optim import GradScaler
 from fairscale.optim.oss import OSS
-from tests.nn.model_parallel.commons import dist_init, get_worker_map
+from fairscale.utils.testing import dist_init, get_worker_map

 try:
     from fairscale.optim import Adam  # type: ignore
tests/nn/model_parallel/commons.py → fairscale/utils/testing.py (renamed)

 # coding=utf-8
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 #
 # This source code is licensed under the BSD license found in the

@@ -11,7 +9,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -19,14 +17,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# We're not responsible for pytest decorators
+# mypy: disallow_untyped_decorators = False
+
+"""
+Collection of some testing utilities for the Fairscale library. Please complement as you see fit, but refrain from ad-hoc test utils
+within the different feature sets and relative imports.
+"""

 import functools
 import inspect
+import logging
 import multiprocessing
 import os
 import random
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy
-from packaging import version
 import pytest
 import torch
 import torch.distributed as dist
@@ -38,11 +45,11 @@ from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
 class IdentityLayer(torch.nn.Module):
-    def __init__(self, size, scale=1.0):
+    def __init__(self, size: int, scale: float = 1.0) -> None:
         super(IdentityLayer, self).__init__()
         self.weight = torch.nn.Parameter(scale * torch.randn(size))

-    def forward(self):
+    def forward(self, *_: Any, **__: Any) -> Any:
         return self.weight
@@ -54,7 +61,26 @@ def set_random_seed(seed: int) -> None:
     model_parallel_cuda_manual_seed(seed)


-def dist_init(rank, world_size, hostname=None):
+def torch_version() -> Tuple[int, ...]:
+    numbering = torch.__version__.split(".")
+    assert len(numbering) == 3
+
+    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
+    if not numbering[2].isnumeric():
+        # Two options here:
+        # - either skip this version (minor number check is not relevant)
+        # - or check that our codebase is not broken by this ongoing development.
+        # Assuming that we're interested in the second usecase more than the first,
+        # return the pre-release or dev numbering
+        logging.warning(f"Pytorch pre-relase version {torch.__version__} - assuming intent to test it")
+        numbering[2] = "0"
+
+    return tuple(int(n) for n in numbering)
+
+
+def dist_init(rank: int, world_size: int, hostname: Optional[str] = None) -> None:
     if hostname is None:
         hostname = "localhost"
     print(f"dist init r={rank}, world={world_size}, host={hostname}")
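torch_version() replaces the ad-hoc packaging.version parsing used previously (and the duplicate helper removed from tests/nn/pipe_process/test_pipe.py further down). Its typical use is gating tests or code paths on the installed PyTorch release; an illustrative snippet mirroring the xfail marker in test_pipe.py (test name hypothetical):

import pytest

from fairscale.utils.testing import torch_version

# e.g. torch.__version__ == "1.6.0"   -> torch_version() == (1, 6, 0)
#      torch.__version__ == "1.8.0a0fb" -> a warning is logged and (1, 8, 0) is returned


@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
def test_needs_recent_torch() -> None:
    ...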
@@ -63,7 +89,7 @@ def dist_init(rank, world_size, hostname=None):
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["RANK"] = str(rank)

-    if version.parse(torch.__version__).release >= (1, 6, 0):
+    if torch_version() >= (1, 6, 0):
         init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
         backend = "nccl" if torch.cuda.is_available() else "gloo"
         torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=init_method)
@@ -77,6 +103,7 @@ def dist_init(rank, world_size, hostname=None):
             backend=rpc.BackendType.TENSORPIPE,
             rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
         )
     else:
+        if world_size > 1:
             rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
@@ -87,21 +114,21 @@ def dist_init(rank, world_size, hostname=None):
         torch.cuda.set_device(rank % torch.cuda.device_count())


-def get_worker_map():
+def get_worker_map() -> Dict[Any, Any]:
     return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}


-def get_world_sizes():
+def get_world_sizes() -> List[int]:
     limit = torch.cuda.device_count()
     return [x for x in [1, 2, 4, 8] if x <= limit]


-def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]):
+def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
     for world_size in world_sizes:
-        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)
+        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)  # type: ignore


-def worker_process(rank, world_size, func, args, error_queue):
+def worker_process(rank: int, world_size: int, func: Callable, args: Any, error_queue: Any) -> None:
     """Main function for unit tests launced with torch_spawn"""
     dist_init(rank, world_size)
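As a quick illustration of the two small helpers above (values assume a machine with two visible CUDA devices and an already-initialized two-rank process group):

from fairscale.utils.testing import get_worker_map, get_world_sizes

print(get_world_sizes())  # -> [1, 2] when torch.cuda.device_count() == 2
print(get_worker_map())   # -> {0: "Test0", 1: "Test1"} with a 2-rank process group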
@@ -120,11 +147,11 @@ def worker_process(rank, world_size, func, args, error_queue):
         raise e


-def torch_spawn(world_sizes=None):
+def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
     if world_sizes is None:
         world_sizes = get_world_sizes()

-    def prepare_test(func):
+    def prepare_test(func: Callable) -> Callable:
         """Function called with the test function as the argument. Generates a
         replacement which serves as the actual test function."""
@@ -138,8 +165,10 @@ def torch_spawn(world_sizes=None):
         )

         @functools.wraps(func)
-        def replacement(*args, **kwargs):
+        def replacement(*args: Any, **kwargs: Any) -> None:
             assert args == tuple()
+            assert world_sizes is not None  # mypy crutch
+
             args = tuple(
                 kwargs[p] for p in parameters if p != "rank"
             )  # converting named parameters to positional parameters to pass to `spawn`
@@ -174,7 +203,9 @@ def torch_spawn(world_sizes=None):
             # Register a function with the same name, prefixed with "test_" in the
             # calling module, so it will be picked up by pytest
-            caller_module = inspect.getmodule(inspect.currentframe().f_back)
+            current_frame = inspect.currentframe()
+            assert current_frame is not None
+            caller_module = inspect.getmodule(current_frame.f_back)
             setattr(caller_module, f"test_{name}", replacement)

         return func
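The registered replacement is what pytest collects, so tests are written as plain functions without the test_ prefix, decorated with the requested world sizes; extra pytest parameters are forwarded by name per the keyword-to-positional conversion above. A rough sketch of a consumer (hypothetical test body, same shape as the pipe_process tests below):

import torch.distributed as dist

from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
def worker_map_is_complete() -> None:
    # runs once per rank inside the spawned 2-process group;
    # torch_spawn registers this as test_worker_map_is_complete for pytest
    assert len(get_worker_map()) == dist.get_world_size()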
pyproject.toml

@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
stubs/torch/cuda/__init__.pyi

@@ -7,6 +7,7 @@ from .. import device as _device
 def is_available() -> bool: ...
 def init() -> None: ...
 def _lazy_call(callable) -> None: ...
+def _sleep(_: int) -> None : ...

 class cudaStatus:
     SUCCESS: int

@@ -64,6 +65,12 @@ class Stream:
     def synchronize(self) -> None: ...
     def wait_stream(self, stream: Stream) -> None: ...

+class Event:
+    def __new__(cls, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> "Event": ...
+    def record(self, stream: Optional[Stream] = None) -> None: ...
+    def synchronize(self) -> None: ...
+    def elapsed_time(self, end_event: Event) -> int: ...
+
 class stream:
     def __init__(self, stream: Optional[Stream] = ...) -> None: ...
     def __enter__(self) -> None: ...
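These stub additions mirror what the refactored conftest fixtures rely on: torch.cuda._sleep to busy-wait a known number of device cycles, and torch.cuda.Event to time it. Roughly the calibration pattern used by the cuda_sleep fixture further down (sketch, not verbatim from this commit):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)  # spin for ~1e6 device cycles
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)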
stubs/torch/distributed/__init__.pyi

@@ -35,7 +35,7 @@ def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optio
 def is_initialized() -> bool: ...
-def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
+def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
 def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
 def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...

@@ -43,6 +43,8 @@ def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional
 def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
 def destroy_process_group() -> None: ...
+def send(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
+def isend(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
 def recv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
stubs/torch/distributed/rpc/__init__.pyi

 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import Union, Callable, Optional
+from typing import Union, Callable, Optional, Any

 from torch.futures import Future

@@ -11,6 +11,11 @@ class RRef:
 class WorkerInfo:
     ...

+class BackendType:
+    TENSORPIPE: Any
+    PROCESS_GROUP: Any
+
+def TensorPipeRpcBackendOptions(init_method: str) -> Any : ...

 def rpc_async(
     to: Union[str, WorkerInfo],

@@ -30,3 +35,8 @@ def rpc_sync(
     timeout=-1.0,
 ) -> None:
     ...
+
+def init_rpc(name: str, backend: Optional[Any] = None, rank:int = -1, world_size: Optional[int] = None, rpc_backend_options: Optional[Any] = None) -> None: ...
+
+def shutdown() -> None: ...
tests/nn/model_parallel/test_cross_entropy.py

@@ -25,7 +25,7 @@ import torch.nn.functional as F
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
 from fairscale.nn.model_parallel.mappings import scatter_to_model_parallel_region
-from tests.nn.model_parallel.commons import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
+from fairscale.utils.testing import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes


 def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
tests/nn/model_parallel/test_initialize.py

@@ -22,7 +22,7 @@
 import torch

 from fairscale.nn.model_parallel import initialize as mpu
-from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


 def run_test_initialize_model_parallel(rank, model_parallel_size):
tests/nn/model_parallel/test_layers.py

@@ -31,13 +31,7 @@ from torch.nn.parameter import Parameter
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel import layers
 from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import (
-    dist_init,
-    get_world_sizes,
-    set_random_seed,
-    spawn_for_all_world_sizes,
-    torch_spawn,
-)
+from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn


 def run_test_parallel_embedding(rank, model_parallel_size):
tests/nn/model_parallel/test_random.py

@@ -24,7 +24,7 @@ import torch
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel import random
 from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed
-from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


 def run_test_set_cuda_rng_state(rank, model_parallel_size):
tests/nn/pipe/conftest.py

@@ -17,17 +17,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
+import os
+from typing import Any, Callable
+
 import pytest
 import torch

+from fairscale.nn.model_parallel import destroy_model_parallel
+

 @pytest.fixture(autouse=True)
-def manual_seed_zero():
+def manual_seed_zero() -> None:
     torch.manual_seed(0)


+def cuda_sleep_impl(seconds, cycles_per_ms):
+    torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
+
+
 @pytest.fixture(scope="session")
-def cuda_sleep():
+def cuda_sleep() -> Callable:
     # Warm-up CUDA.
     torch.empty(1, device="cuda")

@@ -40,11 +50,23 @@ def cuda_sleep():
     end.synchronize()
     cycles_per_ms = 1000000 / start.elapsed_time(end)

-    def cuda_sleep(seconds):
-        torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
-
-    return cuda_sleep
+    return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)


-def pytest_report_header():
+def pytest_report_header() -> str:
     return f"torch: {torch.__version__}"


+def pytest_runtest_setup(item: Any) -> None:
+    print(f"setup mpi function called")
+
+
+def pytest_runtest_teardown(item: Any) -> None:
+    if "OMPI_COMM_WORLD_RANK" in os.environ:
+        destroy_model_parallel()
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+        try:
+            torch.distributed.rpc.shutdown()
+        except Exception:
+            pass
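The nested cuda_sleep closure is replaced by a module-level cuda_sleep_impl bound with functools.partial; from a test's point of view the fixture still returns a callable taking a duration in seconds. A hypothetical consumer of the fixture:

def test_streams_overlap(cuda_sleep) -> None:  # hypothetical test, requests the fixture by name
    # cuda_sleep(seconds) busy-waits the GPU for roughly that long,
    # via cuda_sleep_impl(seconds, cycles_per_ms) bound above
    cuda_sleep(0.1)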
tests/nn/pipe_process/conftest.py

@@ -19,6 +19,7 @@
 import functools
 import os
+from typing import Any, Callable

 import pytest
 import torch

@@ -27,7 +28,7 @@ from fairscale.nn.model_parallel import destroy_model_parallel
 @pytest.fixture(autouse=True)
-def manual_seed_zero():
+def manual_seed_zero() -> None:
     torch.manual_seed(0)

@@ -36,7 +37,7 @@ def cuda_sleep_impl(seconds, cycles_per_ms):
 @pytest.fixture(scope="session")
-def cuda_sleep():
+def cuda_sleep() -> Callable:
     # Warm-up CUDA.
     torch.empty(1, device="cuda")

@@ -52,15 +53,15 @@ def cuda_sleep():
     return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)


-def pytest_report_header():
+def pytest_report_header() -> str:
     return f"torch: {torch.__version__}"


-def pytest_runtest_setup(item):
+def pytest_runtest_setup(item: Any) -> None:
     print(f"setup mpi function called")


-def pytest_runtest_teardown(item):
+def pytest_runtest_teardown(item: Any) -> None:
     if "OMPI_COMM_WORLD_RANK" in os.environ:
         destroy_model_parallel()
         if torch.distributed.is_initialized():
tests/nn/pipe_process/skip/test_gpipe.py

@@ -26,7 +26,7 @@ from torch import nn
 from fairscale.nn.pipe import LazyModule, Pipe
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([3])
tests/nn/pipe_process/skip/test_leak.py

@@ -26,7 +26,7 @@ from torch import nn
 from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.tracker import current_skip_tracker
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @skippable(stash=["skip"])
tests/nn/pipe_process/test_bugs.py

@@ -23,7 +23,7 @@ from torch import nn
 import torch.nn.functional as F

 from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
tests/nn/pipe_process/test_inplace.py

@@ -22,7 +22,7 @@ import torch
 from torch import nn

 from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
tests/nn/pipe_process/test_pipe.py

@@ -21,9 +21,7 @@ from collections import OrderedDict
 from copy import deepcopy
 import os
 import time
-from typing import Tuple

-from packaging import version
 import pytest
 import torch
 from torch import nn

@@ -34,7 +32,7 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
 )
 from fairscale.nn.pipe import LazyModule, Pipe
-from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn
+from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version


 @torch_spawn([2])

@@ -373,24 +371,6 @@ def checkpoint_eval(pipeline_style):
     assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")


-def torch_version() -> Tuple[int, ...]:
-    result = version.parse(torch.__version__).release
-
-    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
-    # for which version.parse().release will return None (version becomes of LegacyVersion type)
-    if result is None:
-        # Two options here:
-        # - either skip this version,
-        # - or check that Pipe is not broken by this ongoing development.
-        # Assuming that we're interested in the second usecase more than the first,
-        # return the pre-release or dev numbering
-        numbering = torch.__version__.split(".")
-        result = (int(numbering[0]), int(numbering[1]), 0)
-
-    assert result
-    return result
-
-
 @torch_spawn([2])
 @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
 @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
tests/nn/pipe_process/test_rpc.py

@@ -8,7 +8,7 @@ from torch.distributed import rpc
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import PipeRPCWrapper
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 def init_rpc():
tests/nn/pipe_process/test_transparency.py

@@ -22,7 +22,7 @@ import torch
 from torch import nn

 from fairscale.nn import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn
+from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn


 @torch_spawn([2])