OpenDAS / fairscale / Commits

Unverified commit e83da060, authored Dec 01, 2020 by Benjamin Lefaudeux, committed by GitHub on Dec 01, 2020
Parent: 1db8bbda

[chore] Refactor unit testing, shared utils (#218)

Showing 19 changed files with 117 additions and 70 deletions (+117 -70)
benchmarks/pipe.py                             +1   -1
fairscale/utils/testing.py                     +48  -17
pyproject.toml                                 +1   -1
stubs/torch/cuda/__init__.pyi                  +7   -0
stubs/torch/distributed/__init__.pyi           +3   -1
stubs/torch/distributed/rpc/__init__.pyi       +11  -1
tests/nn/model_parallel/test_cross_entropy.py  +1   -1
tests/nn/model_parallel/test_initialize.py     +1   -1
tests/nn/model_parallel/test_layers.py         +1   -7
tests/nn/model_parallel/test_random.py         +1   -1
tests/nn/pipe/conftest.py                      +29  -7
tests/nn/pipe_process/conftest.py              +6   -5
tests/nn/pipe_process/skip/test_gpipe.py       +1   -1
tests/nn/pipe_process/skip/test_leak.py        +1   -1
tests/nn/pipe_process/test_bugs.py             +1   -1
tests/nn/pipe_process/test_inplace.py          +1   -1
tests/nn/pipe_process/test_pipe.py             +1   -21
tests/nn/pipe_process/test_rpc.py              +1   -1
tests/nn/pipe_process/test_transparency.py     +1   -1
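The thrust of the change is that the shared test helpers move out of tests.nn.model_parallel.commons and into fairscale.utils.testing, so both the benchmarks and every test suite import them from the library itself. A minimal sketch of the intended usage, assuming the fairscale package is installed; the test name, body, and world size are illustrative only:

import torch.distributed as dist

from fairscale.utils.testing import torch_spawn


# torch_spawn registers a pytest-discoverable "test_<name>" wrapper in this
# module and runs the body once per rank, after dist_init() has completed.
@torch_spawn([2])
def ranks_can_talk() -> None:
    assert dist.get_world_size() == 2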
benchmarks/pipe.py

@@ -23,7 +23,7 @@ from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_
 from fairscale.nn.pipe import LazyModule, pipe
 from fairscale.optim import GradScaler
 from fairscale.optim.oss import OSS
-from tests.nn.model_parallel.commons import dist_init, get_worker_map
+from fairscale.utils.testing import dist_init, get_worker_map

 try:
     from fairscale.optim import Adam  # type: ignore
...
tests/nn/model_parallel/commons.py → fairscale/utils/testing.py
-# coding=utf-8
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 #
 # This source code is licensed under the BSD license found in the
...
@@ -11,7 +9,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -19,14 +17,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# We're not responsible for pytest decorators
+# mypy: disallow_untyped_decorators = False
+
+"""
+Collection of some testing utilities for the Fairscale library. Please complement as you see fit, but refrain from ad-hoc test utils
+within the different feature sets and relative imports.
+"""
+
 import functools
 import inspect
+import logging
 import multiprocessing
 import os
 import random
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy
-from packaging import version
 import pytest
 import torch
 import torch.distributed as dist
...
@@ -38,11 +45,11 @@ from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed

 class IdentityLayer(torch.nn.Module):
-    def __init__(self, size, scale=1.0):
+    def __init__(self, size: int, scale: float = 1.0) -> None:
         super(IdentityLayer, self).__init__()
         self.weight = torch.nn.Parameter(scale * torch.randn(size))

-    def forward(self):
+    def forward(self, *_: Any, **__: Any) -> Any:
         return self.weight
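IdentityLayer is the small trainable module the model-parallel tests use as a source of logits; its forward() now tolerates (and ignores) arbitrary arguments. A minimal sketch of how it behaves, assuming fairscale is installed; the size and scale are illustrative:

from fairscale.utils.testing import IdentityLayer

layer = IdentityLayer(128, scale=2.0)
out = layer()  # forward() ignores its arguments and returns the weight itself
assert out.shape == (128,)
assert out.requires_grad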
...
@@ -54,7 +61,26 @@ def set_random_seed(seed: int) -> None:
     model_parallel_cuda_manual_seed(seed)


-def dist_init(rank, world_size, hostname=None):
+def torch_version() -> Tuple[int, ...]:
+    numbering = torch.__version__.split(".")
+    assert len(numbering) == 3
+
+    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
+    if not numbering[2].isnumeric():
+        # Two options here:
+        # - either skip this version (minor number check is not relevant)
+        # - or check that our codebase is not broken by this ongoing development.
+
+        # Assuming that we're interested in the second usecase more than the first,
+        # return the pre-release or dev numbering
+        logging.warning(f"Pytorch pre-relase version {torch.__version__} - assuming intent to test it")
+        numbering[2] = "0"
+
+    return tuple(int(n) for n in numbering)
+
+
+def dist_init(rank: int, world_size: int, hostname: Optional[str] = None) -> None:
     if hostname is None:
         hostname = "localhost"
     print(f"dist init r={rank}, world={world_size}, host={hostname}")
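torch_version() becomes the shared way to gate tests on the installed PyTorch release. A minimal sketch of that use, assuming fairscale is installed; the threshold and the test body are illustrative (test_pipe.py below uses it with pytest.mark.xfail instead):

import pytest

from fairscale.utils.testing import torch_version


@pytest.mark.skipif(torch_version() < (1, 6, 0), reason="needs torch >= 1.6.0")
def test_on_recent_torch() -> None:
    # torch_version() always yields a 3-tuple of ints, even for pre-release builds.
    assert len(torch_version()) == 3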
...
@@ -63,7 +89,7 @@ def dist_init(rank, world_size, hostname=None):
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["RANK"] = str(rank)

-    if version.parse(torch.__version__).release >= (1, 6, 0):
+    if torch_version() >= (1, 6, 0):
         init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
         backend = "nccl" if torch.cuda.is_available() else "gloo"
         torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=init_method)
...
@@ -77,6 +103,7 @@ def dist_init(rank, world_size, hostname=None):
             backend=rpc.BackendType.TENSORPIPE,
             rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
         )
     else:
         if world_size > 1:
             rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
...
@@ -87,21 +114,21 @@ def dist_init(rank, world_size, hostname=None):
         torch.cuda.set_device(rank % torch.cuda.device_count())


-def get_worker_map():
+def get_worker_map() -> Dict[Any, Any]:
     return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}


-def get_world_sizes():
+def get_world_sizes() -> List[int]:
     limit = torch.cuda.device_count()
     return [x for x in [1, 2, 4, 8] if x <= limit]


-def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]):
+def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
     for world_size in world_sizes:
-        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)
+        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)  # type: ignore


-def worker_process(rank, world_size, func, args, error_queue):
+def worker_process(rank: int, world_size: int, func: Callable, args: Any, error_queue: Any) -> None:
     """Main function for unit tests launced with torch_spawn"""

     dist_init(rank, world_size)
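dist_init() and spawn_for_all_world_sizes() are the pair the model-parallel tests build on: a run_test_* worker initializes the process group, and a plain pytest function spawns it across every eligible world size. A minimal sketch of that pattern, assuming fairscale is installed and the usual MASTER_ADDR / MASTER_PORT environment variables are set; the names and body are illustrative:

import torch.distributed as dist

from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


def run_test_broadcast(rank: int, world_size: int) -> None:
    # mp.spawn passes the rank implicitly; world_size comes from the spawner.
    dist_init(rank, world_size)
    assert dist.get_rank() == rank


def test_broadcast() -> None:
    spawn_for_all_world_sizes(run_test_broadcast)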
...
@@ -120,11 +147,11 @@ def worker_process(rank, world_size, func, args, error_queue):
             raise e


-def torch_spawn(world_sizes=None):
+def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
     if world_sizes is None:
         world_sizes = get_world_sizes()

-    def prepare_test(func):
+    def prepare_test(func: Callable) -> Callable:
         """Function called with the test function as the argument. Generates a
         replacement which serves as the actual test function."""
...
@@ -138,8 +165,10 @@ def torch_spawn(world_sizes=None):
             )

         @functools.wraps(func)
-        def replacement(*args, **kwargs):
+        def replacement(*args: Any, **kwargs: Any) -> None:
             assert args == tuple()
+            assert world_sizes is not None  # mypy crutch
+
             args = tuple(
                 kwargs[p] for p in parameters if p != "rank"
             )  # converting named parameters to positional parameters to pass to `spawn`
...
@@ -174,7 +203,9 @@ def torch_spawn(world_sizes=None):
             # Register a function with the same name, prefixed with "test_" in the
             # calling module, so it will be picked up by pytest
-            caller_module = inspect.getmodule(inspect.currentframe().f_back)
+            current_frame = inspect.currentframe()
+            assert current_frame is not None
+            caller_module = inspect.getmodule(current_frame.f_back)
             setattr(caller_module, f"test_{name}", replacement)

         return func
...
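The hunk above only adds an assertion so mypy accepts that inspect.currentframe() can return None, but the surrounding trick is worth spelling out: torch_spawn injects a pytest-discoverable wrapper into whichever module applied the decorator. A stripped-down sketch of that registration pattern, with an illustrative decorator name that is not part of the fairscale API:

import inspect
from typing import Any, Callable


def register_as_test(func: Callable) -> Callable:
    def replacement(*args: Any, **kwargs: Any) -> None:
        func(*args, **kwargs)

    current_frame = inspect.currentframe()
    assert current_frame is not None  # same mypy crutch as in the diff above
    # f_back is the frame of the module applying the decorator; pytest later
    # collects the injected test_<name> attribute from that module.
    caller_module = inspect.getmodule(current_frame.f_back)
    setattr(caller_module, f"test_{func.__name__}", replacement)
    return func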
pyproject.toml

@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
stubs/torch/cuda/__init__.pyi

@@ -7,6 +7,7 @@ from .. import device as _device
 def is_available() -> bool: ...
 def init() -> None: ...
 def _lazy_call(callable) -> None: ...
+def _sleep(_:int) -> None : ...

 class cudaStatus:
     SUCCESS: int
...
@@ -64,6 +65,12 @@ class Stream:
     def synchronize(self) -> None: ...
     def wait_stream(self, stream: Stream) -> None: ...

+class Event:
+    def __new__(cls, enable_timing: bool = False, blocking:bool = False, interprocess: bool = False) -> "Event": ...
+    def record(self, stream: Optional[Stream] = None) -> None: ...
+    def synchronize(self) -> None: ...
+    def elapsed_time(self, end_event: Event) -> int: ...
+
 class stream:
     def __init__(self, stream: Optional[Stream] = ...) -> None: ...
     def __enter__(self) -> None: ...
...
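These stub entries exist so mypy can type-check the calls the cuda_sleep fixture makes. A minimal sketch of those calls, assuming a CUDA device is present; the cycle count is illustrative:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
torch.cuda._sleep(1_000_000)  # spin the GPU for roughly a million cycles
end.record()
end.synchronize()

print(f"elapsed: {start.elapsed_time(end):.3f} ms")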
stubs/torch/distributed/__init__.pyi

@@ -35,7 +35,7 @@ def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optio
 def is_initialized() -> bool: ...
-def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
+def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
 def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
 def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
...
@@ -43,6 +43,8 @@ def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional
 def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
+def destroy_process_group() -> None: ...
 def send(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
 def isend(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
 def recv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
...
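The new stub entries mirror what the refactored helpers call: dist_init() now passes an explicit init_method, and the conftest teardown calls destroy_process_group(). A minimal single-process sketch of those two calls; the port is illustrative:

import torch.distributed as dist

dist.init_process_group(backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1)
assert dist.is_initialized()
dist.destroy_process_group()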
stubs/torch/distributed/rpc/__init__.pyi

 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import Union, Callable, Optional
+from typing import Union, Callable, Optional, Any
 from torch.futures import Future
...
@@ -11,6 +11,11 @@ class RRef:
 class WorkerInfo:
     ...

+class BackendType:
+    TENSORPIPE: Any
+    PROCESS_GROUP: Any
+
+def TensorPipeRpcBackendOptions(init_method: str) -> Any : ...
+
 def rpc_async(
     to: Union[str, WorkerInfo],
...
@@ -30,3 +35,8 @@ def rpc_sync(
     timeout=-1.0,
 ) -> None:
     ...
+
+def init_rpc(name: str, backend: Optional[Any] = None, rank:int = -1, world_size: Optional[int] = None, rpc_backend_options: Optional[Any] = None) -> None: ...
+
+def shutdown() -> None: ...
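These additions cover the RPC calls made by dist_init() and the conftest teardown. A minimal single-process sketch of that sequence, assuming torch >= 1.6; the worker name and address are illustrative:

from torch.distributed import rpc

rpc.init_rpc(
    "Test0",
    rank=0,
    world_size=1,
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method="tcp://localhost:29502"),
)
rpc.shutdown()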
tests/nn/model_parallel/test_cross_entropy.py

@@ -25,7 +25,7 @@ import torch.nn.functional as F
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
 from fairscale.nn.model_parallel.mappings import scatter_to_model_parallel_region
-from tests.nn.model_parallel.commons import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
+from fairscale.utils.testing import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes


 def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
...
tests/nn/model_parallel/test_initialize.py

@@ -22,7 +22,7 @@
 import torch

 from fairscale.nn.model_parallel import initialize as mpu
-from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


 def run_test_initialize_model_parallel(rank, model_parallel_size):
...
tests/nn/model_parallel/test_layers.py

@@ -31,13 +31,7 @@ from torch.nn.parameter import Parameter
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel import layers
 from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import (
-    dist_init,
-    get_world_sizes,
-    set_random_seed,
-    spawn_for_all_world_sizes,
-    torch_spawn,
-)
+from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn


 def run_test_parallel_embedding(rank, model_parallel_size):
...
tests/nn/model_parallel/test_random.py

@@ -24,7 +24,7 @@ import torch
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel import random
 from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed
-from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


 def run_test_set_cuda_rng_state(rank, model_parallel_size):
...
tests/nn/pipe/conftest.py

@@ -17,17 +17,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
+import os
+from typing import Any, Callable
+
 import pytest
 import torch

+from fairscale.nn.model_parallel import destroy_model_parallel
+

 @pytest.fixture(autouse=True)
-def manual_seed_zero():
+def manual_seed_zero() -> None:
     torch.manual_seed(0)


+def cuda_sleep_impl(seconds, cycles_per_ms):
+    torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
+
+
 @pytest.fixture(scope="session")
-def cuda_sleep():
+def cuda_sleep() -> Callable:
     # Warm-up CUDA.
     torch.empty(1, device="cuda")
...
@@ -40,11 +50,23 @@ def cuda_sleep():
     end.synchronize()
     cycles_per_ms = 1000000 / start.elapsed_time(end)

-    def cuda_sleep(seconds):
-        torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
-
-    return cuda_sleep
+    return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)


-def pytest_report_header():
+def pytest_report_header() -> str:
     return f"torch: {torch.__version__}"
+
+
+def pytest_runtest_setup(item: Any) -> None:
+    print(f"setup mpi function called")
+
+
+def pytest_runtest_teardown(item: Any) -> None:
+    if "OMPI_COMM_WORLD_RANK" in os.environ:
+        destroy_model_parallel()
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+        try:
+            torch.distributed.rpc.shutdown()
+        except Exception:
+            pass
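The cuda_sleep fixture now hands back a functools.partial over the module-level cuda_sleep_impl instead of a nested closure, presumably so the callable can be pickled and shared across spawned processes. A minimal sketch of a test consuming it, assuming a CUDA device; the name and duration are illustrative:

import torch


def test_keep_gpu_busy(cuda_sleep) -> None:
    cuda_sleep(0.1)  # enqueue roughly 100 ms of GPU work
    torch.cuda.synchronize()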
tests/nn/pipe_process/conftest.py

@@ -19,6 +19,7 @@
 import functools
 import os
+from typing import Any, Callable

 import pytest
 import torch
...
@@ -27,7 +28,7 @@ from fairscale.nn.model_parallel import destroy_model_parallel
 @pytest.fixture(autouse=True)
-def manual_seed_zero():
+def manual_seed_zero() -> None:
     torch.manual_seed(0)
...
@@ -36,7 +37,7 @@ def cuda_sleep_impl(seconds, cycles_per_ms):
 @pytest.fixture(scope="session")
-def cuda_sleep():
+def cuda_sleep() -> Callable:
     # Warm-up CUDA.
     torch.empty(1, device="cuda")
...
@@ -52,15 +53,15 @@ def cuda_sleep():
     return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)


-def pytest_report_header():
+def pytest_report_header() -> str:
     return f"torch: {torch.__version__}"


-def pytest_runtest_setup(item):
+def pytest_runtest_setup(item: Any) -> None:
     print(f"setup mpi function called")


-def pytest_runtest_teardown(item):
+def pytest_runtest_teardown(item: Any) -> None:
     if "OMPI_COMM_WORLD_RANK" in os.environ:
         destroy_model_parallel()
         if torch.distributed.is_initialized():
...
tests/nn/pipe_process/skip/test_gpipe.py

@@ -26,7 +26,7 @@ from torch import nn
 from fairscale.nn.pipe import LazyModule, Pipe
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([3])
...
tests/nn/pipe_process/skip/test_leak.py

@@ -26,7 +26,7 @@ from torch import nn
 from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.tracker import current_skip_tracker
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @skippable(stash=["skip"])
...
tests/nn/pipe_process/test_bugs.py

@@ -23,7 +23,7 @@ from torch import nn
 import torch.nn.functional as F

 from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
...
tests/nn/pipe_process/test_inplace.py

@@ -22,7 +22,7 @@ import torch
 from torch import nn

 from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
...
tests/nn/pipe_process/test_pipe.py

@@ -21,9 +21,7 @@ from collections import OrderedDict
 from copy import deepcopy
 import os
 import time
-from typing import Tuple

-from packaging import version
 import pytest
 import torch
 from torch import nn
...
@@ -34,7 +32,7 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
 )
 from fairscale.nn.pipe import LazyModule, Pipe
-from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn
+from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version


 @torch_spawn([2])
...
@@ -373,24 +371,6 @@ def checkpoint_eval(pipeline_style):
     assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")


-def torch_version() -> Tuple[int, ...]:
-    result = version.parse(torch.__version__).release
-    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
-    # for which version.parse().release will return None (version becomes of LegacyVersion type)
-    if result is None:
-        # Two options here:
-        # - either skip this version,
-        # - or check that Pipe is not broken by this ongoing development.
-        # Assuming that we're interested in the second usecase more than the first,
-        # return the pre-release or dev numbering
-        numbering = torch.__version__.split(".")
-        result = (int(numbering[0]), int(numbering[1]), 0)
-
-    assert result
-    return result
-
-
 @torch_spawn([2])
 @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
 @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
...
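The local torch_version() helper deleted here is superseded by the shared one in fairscale.utils.testing, which pads pre-release patch numbers to zero instead of relying on version.parse(). A stripped-down replica of that parsing, on plain strings for illustration only (not fairscale API):

from typing import Tuple


def parse_like_torch_version(raw: str) -> Tuple[int, ...]:
    numbering = raw.split(".")
    if not numbering[2].isnumeric():  # e.g. "1.8.0a0fb" -> treat the patch as 0
        numbering[2] = "0"
    return tuple(int(n) for n in numbering)


assert parse_like_torch_version("1.6.0") == (1, 6, 0)
assert parse_like_torch_version("1.8.0a0fb") == (1, 8, 0)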
tests/nn/pipe_process/test_rpc.py

@@ -8,7 +8,7 @@ from torch.distributed import rpc
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import PipeRPCWrapper
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn


 def init_rpc():
...
tests/nn/pipe_process/test_transparency.py

@@ -22,7 +22,7 @@ import torch
 from torch import nn

 from fairscale.nn import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn
+from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn


 @torch_spawn([2])
...