Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
7b127ccb
Unverified
Commit
7b127ccb
authored
Mar 18, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 18, 2021
Browse files
[fix][OSS] enabling disabled tests for 1.8 (#534)
* enabling disabled tests
parent
8b59267b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
16 deletions
+6
-16
tests/optim/test_oss.py
tests/optim/test_oss.py
+6
-16
No files found.
tests/optim/test_oss.py
View file @
7b127ccb
...
@@ -415,6 +415,7 @@ def test_sharding():
...
@@ -415,6 +415,7 @@ def test_sharding():
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
torch
.
cuda
.
set_device
(
rank
)
# Run a dummy step so that the optimizer state dict exists
# Run a dummy step so that the optimizer state dict exists
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
...
@@ -472,14 +473,10 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
...
@@ -472,14 +473,10 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
# TODO(blefaudeux) Fix for torch v1.8.0
@
skip_if_single_gpu
@
pytest
.
mark
.
skipif
(
torch
.
__version__
.
split
(
"+"
)[
0
].
split
(
"."
)
==
[
"1"
,
"8"
,
"0"
],
reason
=
"disabled for torch 1.8.0"
)
def
test_collect_shards
():
def
test_collect_shards
():
world_size
=
3
world_size
=
2
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
if
torch
.
cuda
.
is_available
():
world_size
=
min
(
world_size
,
torch
.
cuda
.
device_count
())
reference_rank
=
0
reference_rank
=
0
mp
.
spawn
(
mp
.
spawn
(
...
@@ -487,9 +484,10 @@ def test_collect_shards():
...
@@ -487,9 +484,10 @@ def test_collect_shards():
)
)
def
run_test_reproducibility
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
def
run_test_reproducibility
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
torch
.
cuda
.
set_device
(
rank
)
# Run a dummy step so that the optimizer state dict exists
# Run a dummy step so that the optimizer state dict exists
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
...
@@ -535,21 +533,13 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
...
@@ -535,21 +533,13 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
@skip_if_single_gpu
def test_reproducibility():
    """Check that two OSS runs with the same seed/world produce identical results.

    Spawns ``world_size`` workers running ``run_test_reproducibility``; the
    actual comparison logic lives in the worker function.
    """
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]

    # Bail out early if this host cannot provide one GPU per rank
    # (CUDA present but fewer devices than ranks).
    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
        return

    mp.spawn(
        run_test_reproducibility,
        args=(world_size, temp_file_name),
        nprocs=world_size,
        join=True,
    )
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment