Unverified Commit 51c65c97 authored by Yuanhao Zhu's avatar Yuanhao Zhu Committed by GitHub
Browse files

fix syncbn parameter order mismatch and parrots bug (#488)

parent 17e4732c
......@@ -121,9 +121,9 @@ void sync_bn_forward_mean(const Tensor input, Tensor mean);
void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var);
void sync_bn_forward_output(const Tensor input, const Tensor mean,
const Tensor var, Tensor running_mean,
Tensor running_var, const Tensor weight,
const Tensor bias, Tensor norm, Tensor std,
const Tensor var, const Tensor weight,
const Tensor bias, Tensor running_mean,
Tensor running_var, Tensor norm, Tensor std,
Tensor output, float eps, float momentum,
int group_size);
......@@ -299,9 +299,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("input"), py::arg("mean"), py::arg("var"));
m.def("sync_bn_forward_output", &sync_bn_forward_output,
"sync_bn forward_output", py::arg("input"), py::arg("mean"),
py::arg("var"), py::arg("running_mean"), py::arg("running_var"),
py::arg("weight"), py::arg("bias"), py::arg("norm"), py::arg("std"),
py::arg("output"), py::arg("eps"), py::arg("momentum"),
py::arg("var"), py::arg("weight"), py::arg("bias"),
py::arg("running_mean"), py::arg("running_var"), py::arg("norm"),
py::arg("std"), py::arg("output"), py::arg("eps"), py::arg("momentum"),
py::arg("group_size"));
m.def("sync_bn_backward_param", &sync_bn_backward_param,
"sync_bn backward_param", py::arg("grad_output"), py::arg("norm"),
......
......@@ -89,9 +89,9 @@ void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var) {
}
void sync_bn_forward_output(const Tensor input, const Tensor mean,
const Tensor var, Tensor running_mean,
Tensor running_var, const Tensor weight,
const Tensor bias, Tensor norm, Tensor std,
const Tensor var, const Tensor weight,
const Tensor bias, Tensor running_mean,
Tensor running_var, Tensor norm, Tensor std,
Tensor output, float eps, float momentum,
int group_size) {
if (input.device().is_cuda()) {
......@@ -99,10 +99,10 @@ void sync_bn_forward_output(const Tensor input, const Tensor mean,
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(mean);
CHECK_CUDA_INPUT(var);
CHECK_CUDA_INPUT(running_mean);
CHECK_CUDA_INPUT(running_var);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(bias);
CHECK_CUDA_INPUT(running_mean);
CHECK_CUDA_INPUT(running_var);
CHECK_CUDA_INPUT(norm);
CHECK_CUDA_INPUT(std);
CHECK_CUDA_INPUT(output);
......
......@@ -52,14 +52,10 @@ class SyncBatchNormFunction(Function):
input3d.size(1), dtype=torch.float, device=input3d.device)
var = torch.empty(
input3d.size(1), dtype=torch.float, device=input3d.device)
if input3d.requires_grad or weight.requires_grad or bias.requires_grad:
norm = torch.empty_like(
input3d, dtype=torch.float, device=input3d.device)
std = torch.empty(
input3d.size(1), dtype=torch.float, device=input3d.device)
else:
norm = torch.empty(0, dtype=torch.float, device=input3d.device)
std = torch.empty(0, dtype=torch.float, device=input3d.device)
norm = torch.empty_like(
input3d, dtype=torch.float, device=input3d.device)
std = torch.empty(
input3d.size(1), dtype=torch.float, device=input3d.device)
ext_module.sync_bn_forward_mean(input3d, mean)
if self.group_size > 1:
......@@ -73,10 +69,10 @@ class SyncBatchNormFunction(Function):
input3d,
mean,
var,
running_mean,
running_var,
weight,
bias,
running_mean,
running_var,
norm,
std,
output3d,
......
......@@ -21,13 +21,13 @@ class TestSyncBN(object):
node_list = str(os.environ['SLURM_NODELIST'])
node_parts = re.findall('[0-9]+', node_list)
host_ip = '{}.{}.{}.{}'.format(node_parts[1], node_parts[2],
node_parts[3], node_parts[4])
port = '12341'
init_method = 'tcp://{}:{}'.format(host_ip, port)
os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
f'.{node_parts[3]}.{node_parts[4]}')
os.environ['MASTER_PORT'] = '12341'
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(rank)
dist.init_process_group(
'nccl', init_method=init_method, world_size=world_size, rank=rank)
dist.init_process_group('nccl')
torch.cuda.set_device(local_rank)
def _test_syncbn_train(self, size=1, half=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment