OpenDAS / fairscale · Commits · e6aef938
"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "e0753f0b0d7fbbc07556b3e3d2bf7116b784153d"
Unverified commit e6aef938, authored Jan 27, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 27, 2021
[fix] OSS Cpu tests (#333)

parent 38ad8638
Showing 1 changed file with 9 additions and 3 deletions.
tests/optim/test_oss.py @ e6aef938 (+9, -3)
@@ -237,7 +237,7 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
 def test_add_param_group():
     world_size = 4
-    if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
+    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
         world_size = min(world_size, torch.cuda.device_count())
     temp_file_name = tempfile.mkstemp()[1]
@@ -262,6 +262,9 @@ def run_test_zero_grad(rank, world_size, tempfile_name):
 def test_zero_grad():
     world_size = 2
+    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
+        world_size = min(world_size, torch.cuda.device_count())
     temp_file_name = tempfile.mkstemp()[1]

     mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
@@ -474,7 +477,11 @@ def run_test_multiple_groups(rank, world_size, tempfile_name):
             dist.gather(p, receptacle, dst=0, group=process_group)
             if rank == 0:
                 for sync_p in receptacle[1:]:
-                    assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
+                    assert torch.all(
+                        torch.eq(receptacle[0], sync_p)
+                    ), "Models differ in between ranks {} - {}".format(
+                        torch.norm(receptacle[0]), torch.norm(sync_p)
+                    )

     if rank in sub_group_ranks:
         # Model fitting in the broadcast bucket
@@ -498,7 +505,6 @@ def run_test_multiple_groups(rank, world_size, tempfile_name):
         check(optimizer)

-    dist.destroy_process_group(process_group)
     dist.destroy_process_group()


 def test_multiple_groups():
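The recurring change in this commit is to shrink a test's world_size only when CUDA is present but exposes fewer devices than requested, so that a CPU-only machine still spawns the requested number of processes (presumably over the gloo backend) instead of collapsing to zero workers, as the old "if not torch.cuda.is_available() or ..." check did. Below is a minimal, self-contained sketch of that pattern; the names run_test_example and test_example are hypothetical stand-ins for the fairscale test helpers, not part of the repository.

import tempfile

import torch
import torch.multiprocessing as mp


def run_test_example(rank, world_size, tempfile_name):
    # Hypothetical worker: a real fairscale test would initialize the process
    # group here (gloo on CPU, nccl on GPU) and exercise the OSS optimizer.
    pass


def test_example():
    world_size = 4
    # Clamp only when CUDA exists but has too few devices. The previous
    # "if not torch.cuda.is_available() or ..." form also triggered on
    # CPU-only hosts, where device_count() == 0 collapsed world_size to 0.
    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
        world_size = min(world_size, torch.cuda.device_count())
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(run_test_example, args=(world_size, temp_file_name), nprocs=world_size, join=True)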