Commit ed47ebff
Authored Sep 18, 2018 by Michael Carilli

Forward compatibility fixes for distributed backend, thanks to @Ssnl

Parent: 0ec8addb
1 changed file, with 13 additions and 3 deletions:

apex/parallel/distributed.py (+13, -3)
@@ -129,7 +129,17 @@ class DistributedDataParallel(Module):
     def __init__(self, module, message_size=10000000, shared_param=False):
         super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+        # Backward/forward compatibility around
+        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36
+        if(hasattr(dist, "get_backend")):
+            self._backend = dist.get_backend()
+            self.backend_enum_holder = dist.DistBackend
+        else:
+            self._backend = dist._backend
+            self.backend_enum_holder = dist.dist_backend
+
+        self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False

         self.shared_param = shared_param
         self.message_size = message_size
@@ -141,7 +151,7 @@ class DistributedDataParallel(Module):
         self.module = module
         self.param_list = list(self.module.parameters())

-        if dist._backend == dist.dist_backend.NCCL:
+        if self._backend == self.backend_enum_holder.NCCL:
             for param in self.param_list:
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
@@ -156,7 +166,7 @@ class DistributedDataParallel(Module):
     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
-        if dist._backend != dist.dist_backend.NCCL:
+        if self._backend != self.backend_enum_holder.NCCL:
             del attrs['self.reduction_stream']
         return attrs
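The heart of the change is a small compatibility shim around the torch.distributed API change linked in the diff: if dist.get_backend() exists, the wrapper uses it together with the dist.DistBackend enum; otherwise it falls back to the older private dist._backend attribute and dist.dist_backend enum, and all later backend comparisons go through the resolved values. The sketch below shows the same pattern in isolation, assuming a PyTorch build from that era; the resolve_backend helper name is made up for illustration and is not part of apex.

import torch.distributed as dist

def resolve_backend():
    # Hypothetical helper mirroring the shim in this commit:
    # prefer the public API when it exists (newer torch.distributed at the time),
    # otherwise fall back to the older private attributes.
    # Note: dist.get_backend() assumes init_process_group() was already called.
    if hasattr(dist, "get_backend"):
        backend = dist.get_backend()      # backend of the default process group
        enum_holder = dist.DistBackend    # holds the GLOO / NCCL / MPI members
    else:
        backend = dist._backend           # pre-get_backend() private attribute
        enum_holder = dist.dist_backend   # pre-get_backend() enum location
    return backend, enum_holder

# Example usage mirroring the diff (assumes the default process group has been
# initialized, e.g. via dist.init_process_group(...)):
#   backend, enum_holder = resolve_backend()
#   warn_on_half = (backend == enum_holder.GLOO)

Comparing against the resolved enum_holder rather than dist.dist_backend directly is what lets the same code run on both older and newer PyTorch versions, which is the point of the commit.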