OpenDAS / Pytorch-Encoding · Commit 58b93a74 (unverified)

fix syncbn bug for pytorch 1.6 (#355)

Authored Mar 07, 2021 by zhangbin0917; committed by GitHub Mar 06, 2021
Parent: ced288d6
Showing 2 changed files with 14 additions and 8 deletions (+14 −8):

    encoding/functions/syncbn.py    +8 −4
    encoding/nn/syncbn.py           +6 −4
encoding/functions/syncbn.py

@@ -103,11 +103,13 @@ class syncbatchnorm_(Function):
         # Output
         ctx.save_for_backward(x, _ex, _exs, gamma, beta)
-        return y
+        ctx.mark_non_differentiable(running_mean, running_var)
+        return y, running_mean, running_var

     @staticmethod
     @once_differentiable
-    def backward(ctx, dz):
+    def backward(ctx, dz, _drunning_mean, _drunning_var):
         x, _ex, _exs, gamma, beta = ctx.saved_tensors
         dz = dz.contiguous()

@@ -243,11 +245,13 @@ class inp_syncbatchnorm_(Function):
         # Output
         ctx.save_for_backward(x, _ex, _exs, gamma, beta)
-        return x
+        ctx.mark_non_differentiable(running_mean, running_var)
+        return x, running_mean, running_var

     @staticmethod
     @once_differentiable
-    def backward(ctx, dz):
+    def backward(ctx, dz, _drunning_mean, _drunning_var):
         z, _ex, _exs, gamma, beta = ctx.saved_tensors
         dz = dz.contiguous()
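Both Functions follow the pattern documented for torch.autograd.Function: once the in-place-updated running statistics are returned as extra outputs and marked non-differentiable, backward must accept one gradient argument per forward output, and the gradients for the marked outputs arrive as None. A minimal runnable sketch of that pattern (hypothetical names, not from this repository; the statistic update here is out-of-place for simplicity, whereas syncbn updates running_mean/running_var in-place inside its CUDA extension):

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class scale_with_stat_(Function):
    # Hypothetical Function mirroring the syncbn change: one differentiable
    # output plus one non-differentiable statistic output.
    @staticmethod
    def forward(ctx, x, running_mean):
        new_mean = running_mean * 0.9 + x.detach().mean() * 0.1
        ctx.mark_non_differentiable(new_mean)
        return x * 2, new_mean

    @staticmethod
    @once_differentiable
    def backward(ctx, dz, _dnew_mean):
        # One gradient per forward output; the gradient for the
        # non-differentiable output arrives as None and is ignored.
        return dz * 2, None

x = torch.randn(4, requires_grad=True)
y, mean = scale_with_stat_.apply(x, torch.zeros(()))
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])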
encoding/nn/syncbn.py

@@ -193,13 +193,15 @@ class SyncBatchNorm(_BatchNorm):
                      "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]}
         if self.inplace:
-            return inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
+            y, _, _ = inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                                      extra, self.sync, self.training, self.momentum, self.eps,
-                                     self.activation, self.slope).view(input_shape)
+                                     self.activation, self.slope)
+            return y.view(input_shape)
         else:
-            return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
+            y, _, _ = syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                                  extra, self.sync, self.training, self.momentum, self.eps,
-                                 self.activation, self.slope).view(input_shape)
+                                 self.activation, self.slope)
+            return y.view(input_shape)

     def extra_repr(self):
         if self.activation == 'none':
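On the module side, the only consequence is that the Functions now return 3-tuples, so SyncBatchNorm.forward unpacks y, _, _ before reshaping; chaining .view() on the call as before would now raise AttributeError: 'tuple' object has no attribute 'view'. A quick smoke test of the fixed module, assuming a CUDA build of the package's extensions (not part of this commit):

import torch
import encoding

norm = encoding.nn.SyncBatchNorm(8).cuda()
x = torch.randn(2, 8, 4, 4, device='cuda', requires_grad=True)
y = norm(x)            # forward unpacks (y, running_mean, running_var) internally
y.mean().backward()    # backward receives and ignores the two extra gradients
print(y.shape)         # torch.Size([2, 8, 4, 4])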