Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
7b127ccb
Unverified
Commit
7b127ccb
authored
Mar 18, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 18, 2021
Browse files
[fix][OSS] enabling disabled tests for 1.8 (#534)
* enabling disabled tests
parent
8b59267b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
16 deletions
+6
-16
tests/optim/test_oss.py
tests/optim/test_oss.py
+6
-16
No files found.
tests/optim/test_oss.py
View file @
7b127ccb
...
@@ -415,6 +415,7 @@ def test_sharding():
...
@@ -415,6 +415,7 @@ def test_sharding():
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
torch
.
cuda
.
set_device
(
rank
)
# Run a dummy step so that the optimizer state dict exists
# Run a dummy step so that the optimizer state dict exists
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
...
@@ -472,14 +473,10 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
...
@@ -472,14 +473,10 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
# TODO(blefaudeux) Fix for torch v1.8.0
@
skip_if_single_gpu
@
pytest
.
mark
.
skipif
(
torch
.
__version__
.
split
(
"+"
)[
0
].
split
(
"."
)
==
[
"1"
,
"8"
,
"0"
],
reason
=
"disabled for torch 1.8.0"
)
def
test_collect_shards
():
def
test_collect_shards
():
world_size
=
3
world_size
=
2
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
if
torch
.
cuda
.
is_available
():
world_size
=
min
(
world_size
,
torch
.
cuda
.
device_count
())
reference_rank
=
0
reference_rank
=
0
mp
.
spawn
(
mp
.
spawn
(
...
@@ -487,9 +484,10 @@ def test_collect_shards():
...
@@ -487,9 +484,10 @@ def test_collect_shards():
)
)
def
run_test_reproducibility
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
def
run_test_reproducibility
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
torch
.
cuda
.
set_device
(
rank
)
# Run a dummy step so that the optimizer state dict exists
# Run a dummy step so that the optimizer state dict exists
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
...
@@ -535,21 +533,13 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
...
@@ -535,21 +533,13 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
# TODO(blefaudeux) Fix for torch v1.8.0
@
pytest
.
mark
.
skipif
(
torch
.
__version__
.
split
(
"+"
)[
0
].
split
(
"."
)
==
[
"1"
,
"8"
,
"0"
],
reason
=
"disabled for torch 1.8.0"
)
@
skip_if_single_gpu
@
skip_if_single_gpu
def
test_reproducibility
():
def
test_reproducibility
():
world_size
=
2
world_size
=
2
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
if
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
device_count
()
<
world_size
:
# Bail out if not enough devices
return
reference_rank
=
0
mp
.
spawn
(
mp
.
spawn
(
run_test_reproducibility
,
args
=
(
world_size
,
reference_rank
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
run_test_reproducibility
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment