Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
51c65c97
Unverified
Commit
51c65c97
authored
Aug 13, 2020
by
Yuanhao Zhu
Committed by
GitHub
Aug 13, 2020
Browse files
fix syncbn parameter order mismatch and parrots bug (#488)
parent
17e4732c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
27 deletions
+23
-27
mmcv/ops/csrc/pytorch/pybind.cpp
mmcv/ops/csrc/pytorch/pybind.cpp
+6
-6
mmcv/ops/csrc/pytorch/sync_bn.cpp
mmcv/ops/csrc/pytorch/sync_bn.cpp
+5
-5
mmcv/ops/sync_bn.py
mmcv/ops/sync_bn.py
+6
-10
tests/test_ops/test_syncbn.py
tests/test_ops/test_syncbn.py
+6
-6
No files found.
mmcv/ops/csrc/pytorch/pybind.cpp
View file @
51c65c97
...
@@ -121,9 +121,9 @@ void sync_bn_forward_mean(const Tensor input, Tensor mean);
...
@@ -121,9 +121,9 @@ void sync_bn_forward_mean(const Tensor input, Tensor mean);
void
sync_bn_forward_var
(
const
Tensor
input
,
const
Tensor
mean
,
Tensor
var
);
void
sync_bn_forward_var
(
const
Tensor
input
,
const
Tensor
mean
,
Tensor
var
);
void
sync_bn_forward_output
(
const
Tensor
input
,
const
Tensor
mean
,
void
sync_bn_forward_output
(
const
Tensor
input
,
const
Tensor
mean
,
const
Tensor
var
,
Tensor
running_mean
,
const
Tensor
var
,
const
Tensor
weight
,
Tensor
running_var
,
const
Tensor
weight
,
const
Tensor
bias
,
Tensor
running_mean
,
const
Tensor
bias
,
Tensor
norm
,
Tensor
std
,
Tensor
running_var
,
Tensor
norm
,
Tensor
std
,
Tensor
output
,
float
eps
,
float
momentum
,
Tensor
output
,
float
eps
,
float
momentum
,
int
group_size
);
int
group_size
);
...
@@ -299,9 +299,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -299,9 +299,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
arg
(
"input"
),
py
::
arg
(
"mean"
),
py
::
arg
(
"var"
));
py
::
arg
(
"input"
),
py
::
arg
(
"mean"
),
py
::
arg
(
"var"
));
m
.
def
(
"sync_bn_forward_output"
,
&
sync_bn_forward_output
,
m
.
def
(
"sync_bn_forward_output"
,
&
sync_bn_forward_output
,
"sync_bn forward_output"
,
py
::
arg
(
"input"
),
py
::
arg
(
"mean"
),
"sync_bn forward_output"
,
py
::
arg
(
"input"
),
py
::
arg
(
"mean"
),
py
::
arg
(
"var"
),
py
::
arg
(
"
running_mean
"
),
py
::
arg
(
"
running_var
"
),
py
::
arg
(
"var"
),
py
::
arg
(
"
weight
"
),
py
::
arg
(
"
bias
"
),
py
::
arg
(
"
weight
"
),
py
::
arg
(
"
bias
"
),
py
::
arg
(
"norm"
),
py
::
arg
(
"std"
),
py
::
arg
(
"
running_mean
"
),
py
::
arg
(
"
running_var
"
),
py
::
arg
(
"norm"
),
py
::
arg
(
"output"
),
py
::
arg
(
"eps"
),
py
::
arg
(
"momentum"
),
py
::
arg
(
"std"
),
py
::
arg
(
"output"
),
py
::
arg
(
"eps"
),
py
::
arg
(
"momentum"
),
py
::
arg
(
"group_size"
));
py
::
arg
(
"group_size"
));
m
.
def
(
"sync_bn_backward_param"
,
&
sync_bn_backward_param
,
m
.
def
(
"sync_bn_backward_param"
,
&
sync_bn_backward_param
,
"sync_bn backward_param"
,
py
::
arg
(
"grad_output"
),
py
::
arg
(
"norm"
),
"sync_bn backward_param"
,
py
::
arg
(
"grad_output"
),
py
::
arg
(
"norm"
),
...
...
mmcv/ops/csrc/pytorch/sync_bn.cpp
View file @
51c65c97
...
@@ -89,9 +89,9 @@ void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var) {
...
@@ -89,9 +89,9 @@ void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var) {
}
}
void
sync_bn_forward_output
(
const
Tensor
input
,
const
Tensor
mean
,
void
sync_bn_forward_output
(
const
Tensor
input
,
const
Tensor
mean
,
const
Tensor
var
,
Tensor
running_mean
,
const
Tensor
var
,
const
Tensor
weight
,
Tensor
running_var
,
const
Tensor
weight
,
const
Tensor
bias
,
Tensor
running_mean
,
const
Tensor
bias
,
Tensor
norm
,
Tensor
std
,
Tensor
running_var
,
Tensor
norm
,
Tensor
std
,
Tensor
output
,
float
eps
,
float
momentum
,
Tensor
output
,
float
eps
,
float
momentum
,
int
group_size
)
{
int
group_size
)
{
if
(
input
.
device
().
is_cuda
())
{
if
(
input
.
device
().
is_cuda
())
{
...
@@ -99,10 +99,10 @@ void sync_bn_forward_output(const Tensor input, const Tensor mean,
...
@@ -99,10 +99,10 @@ void sync_bn_forward_output(const Tensor input, const Tensor mean,
CHECK_CUDA_INPUT
(
input
);
CHECK_CUDA_INPUT
(
input
);
CHECK_CUDA_INPUT
(
mean
);
CHECK_CUDA_INPUT
(
mean
);
CHECK_CUDA_INPUT
(
var
);
CHECK_CUDA_INPUT
(
var
);
CHECK_CUDA_INPUT
(
running_mean
);
CHECK_CUDA_INPUT
(
running_var
);
CHECK_CUDA_INPUT
(
weight
);
CHECK_CUDA_INPUT
(
weight
);
CHECK_CUDA_INPUT
(
bias
);
CHECK_CUDA_INPUT
(
bias
);
CHECK_CUDA_INPUT
(
running_mean
);
CHECK_CUDA_INPUT
(
running_var
);
CHECK_CUDA_INPUT
(
norm
);
CHECK_CUDA_INPUT
(
norm
);
CHECK_CUDA_INPUT
(
std
);
CHECK_CUDA_INPUT
(
std
);
CHECK_CUDA_INPUT
(
output
);
CHECK_CUDA_INPUT
(
output
);
...
...
mmcv/ops/sync_bn.py
View file @
51c65c97
...
@@ -52,14 +52,10 @@ class SyncBatchNormFunction(Function):
...
@@ -52,14 +52,10 @@ class SyncBatchNormFunction(Function):
input3d
.
size
(
1
),
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
input3d
.
size
(
1
),
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
var
=
torch
.
empty
(
var
=
torch
.
empty
(
input3d
.
size
(
1
),
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
input3d
.
size
(
1
),
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
if
input3d
.
requires_grad
or
weight
.
requires_grad
or
bias
.
requires_grad
:
norm
=
torch
.
empty_like
(
norm
=
torch
.
empty_like
(
input3d
,
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
input3d
,
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
std
=
torch
.
empty
(
std
=
torch
.
empty
(
input3d
.
size
(
1
),
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
input3d
.
size
(
1
),
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
else
:
norm
=
torch
.
empty
(
0
,
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
std
=
torch
.
empty
(
0
,
dtype
=
torch
.
float
,
device
=
input3d
.
device
)
ext_module
.
sync_bn_forward_mean
(
input3d
,
mean
)
ext_module
.
sync_bn_forward_mean
(
input3d
,
mean
)
if
self
.
group_size
>
1
:
if
self
.
group_size
>
1
:
...
@@ -73,10 +69,10 @@ class SyncBatchNormFunction(Function):
...
@@ -73,10 +69,10 @@ class SyncBatchNormFunction(Function):
input3d
,
input3d
,
mean
,
mean
,
var
,
var
,
running_mean
,
running_var
,
weight
,
weight
,
bias
,
bias
,
running_mean
,
running_var
,
norm
,
norm
,
std
,
std
,
output3d
,
output3d
,
...
...
tests/test_ops/test_syncbn.py
View file @
51c65c97
...
@@ -21,13 +21,13 @@ class TestSyncBN(object):
...
@@ -21,13 +21,13 @@ class TestSyncBN(object):
node_list
=
str
(
os
.
environ
[
'SLURM_NODELIST'
])
node_list
=
str
(
os
.
environ
[
'SLURM_NODELIST'
])
node_parts
=
re
.
findall
(
'[0-9]+'
,
node_list
)
node_parts
=
re
.
findall
(
'[0-9]+'
,
node_list
)
host_ip
=
'{}.{}.{}.{}'
.
format
(
node_parts
[
1
],
node_parts
[
2
],
os
.
environ
[
'MASTER_ADDR'
]
=
(
f
'
{
node_parts
[
1
]
}
.
{
node_parts
[
2
]
}
'
+
node_parts
[
3
],
node_parts
[
4
])
f
'.
{
node_parts
[
3
]
}
.
{
node_parts
[
4
]
}
'
)
port
=
'12341'
os
.
environ
[
'MASTER_PORT'
]
=
'12341'
init_method
=
'tcp://{}:{}'
.
format
(
host_ip
,
port
)
os
.
environ
[
'WORLD_SIZE'
]
=
str
(
world_size
)
os
.
environ
[
'RANK'
]
=
str
(
rank
)
dist
.
init_process_group
(
dist
.
init_process_group
(
'nccl'
)
'nccl'
,
init_method
=
init_method
,
world_size
=
world_size
,
rank
=
rank
)
torch
.
cuda
.
set_device
(
local_rank
)
torch
.
cuda
.
set_device
(
local_rank
)
def
_test_syncbn_train
(
self
,
size
=
1
,
half
=
False
):
def
_test_syncbn_train
(
self
,
size
=
1
,
half
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment