Unverified commit 51c65c97, authored by Yuanhao Zhu and committed by GitHub

fix syncbn parameter order mismatch and parrots bug (#488)

parent 17e4732c
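The fix has two parts. First, the argument order of `sync_bn_forward_output` disagreed between the C++/pybind11 declaration (which took `running_mean` and `running_var` before `weight` and `bias`) and the Python call site in `SyncBatchNormFunction.forward`; because the Python side passes every tensor positionally, the tensors were silently bound to the wrong parameters instead of raising an error. Second, the `norm` and `std` buffers were only allocated at full size when gradients were required, which misbehaved under the parrots backend. A toy sketch (hypothetical function, not mmcv's kernel) of why a positional-order mismatch is silent:

import torch

def old_signature(var, running_mean, running_var, weight, bias):
    # Toy stand-in for the pre-fix binding order (not mmcv's actual kernel).
    return (var + 1e-5).rsqrt() * weight + bias

var = torch.ones(4)
weight, bias = torch.full((4,), 2.0), torch.zeros(4)
running_mean, running_var = torch.zeros(4), torch.ones(4)

# Call site written against the *new* order but passed positionally:
# `weight` silently binds to `running_mean`, `bias` to `running_var`, etc.
wrong = old_signature(var, weight, bias, running_mean, running_var)
right = old_signature(var, running_mean, running_var, weight, bias)
assert not torch.allclose(wrong, right)  # no error raised, just wrong numbers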
@@ -121,9 +121,9 @@ void sync_bn_forward_mean(const Tensor input, Tensor mean);
 void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var);
 void sync_bn_forward_output(const Tensor input, const Tensor mean,
-                            const Tensor var, Tensor running_mean,
-                            Tensor running_var, const Tensor weight,
-                            const Tensor bias, Tensor norm, Tensor std,
+                            const Tensor var, const Tensor weight,
+                            const Tensor bias, Tensor running_mean,
+                            Tensor running_var, Tensor norm, Tensor std,
                             Tensor output, float eps, float momentum,
                             int group_size);
@@ -299,9 +299,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("input"), py::arg("mean"), py::arg("var"));
   m.def("sync_bn_forward_output", &sync_bn_forward_output,
         "sync_bn forward_output", py::arg("input"), py::arg("mean"),
-        py::arg("var"), py::arg("running_mean"), py::arg("running_var"),
-        py::arg("weight"), py::arg("bias"), py::arg("norm"), py::arg("std"),
-        py::arg("output"), py::arg("eps"), py::arg("momentum"),
+        py::arg("var"), py::arg("weight"), py::arg("bias"),
+        py::arg("running_mean"), py::arg("running_var"), py::arg("norm"),
+        py::arg("std"), py::arg("output"), py::arg("eps"), py::arg("momentum"),
         py::arg("group_size"));
   m.def("sync_bn_backward_param", &sync_bn_backward_param,
         "sync_bn backward_param", py::arg("grad_output"), py::arg("norm"),
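Since every parameter of the binding is named via `py::arg`, the extension can also be called with keyword arguments, which is immune to any reordering. A sketch of such a call under the new order, assuming mmcv 1.x with its compiled `_ext` module and a CUDA device (shapes illustrative, not a definitive usage):

import torch
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['sync_bn_forward_output'])

N, C, L = 2, 8, 16
x = torch.randn(N, C, L, device='cuda')  # (N, C, L) input, float32, contiguous
out = torch.empty_like(x)
# Keyword call: robust to the positional reorder fixed in this commit.
ext_module.sync_bn_forward_output(
    input=x,
    mean=x.mean(dim=(0, 2)),
    var=x.var(dim=(0, 2), unbiased=False),
    weight=torch.ones(C, device='cuda'),
    bias=torch.zeros(C, device='cuda'),
    running_mean=torch.zeros(C, device='cuda'),
    running_var=torch.ones(C, device='cuda'),
    norm=torch.empty_like(x),
    std=torch.empty(C, device='cuda'),
    output=out,
    eps=1e-5,
    momentum=0.1,
    group_size=1)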
@@ -89,9 +89,9 @@ void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var) {
 }

 void sync_bn_forward_output(const Tensor input, const Tensor mean,
-                            const Tensor var, Tensor running_mean,
-                            Tensor running_var, const Tensor weight,
-                            const Tensor bias, Tensor norm, Tensor std,
+                            const Tensor var, const Tensor weight,
+                            const Tensor bias, Tensor running_mean,
+                            Tensor running_var, Tensor norm, Tensor std,
                             Tensor output, float eps, float momentum,
                             int group_size) {
   if (input.device().is_cuda()) {
@@ -99,10 +99,10 @@ void sync_bn_forward_output(const Tensor input, const Tensor mean,
     CHECK_CUDA_INPUT(input);
     CHECK_CUDA_INPUT(mean);
     CHECK_CUDA_INPUT(var);
-    CHECK_CUDA_INPUT(running_mean);
-    CHECK_CUDA_INPUT(running_var);
     CHECK_CUDA_INPUT(weight);
     CHECK_CUDA_INPUT(bias);
+    CHECK_CUDA_INPUT(running_mean);
+    CHECK_CUDA_INPUT(running_var);
     CHECK_CUDA_INPUT(norm);
     CHECK_CUDA_INPUT(std);
     CHECK_CUDA_INPUT(output);
@@ -52,14 +52,10 @@ class SyncBatchNormFunction(Function):
             input3d.size(1), dtype=torch.float, device=input3d.device)
         var = torch.empty(
             input3d.size(1), dtype=torch.float, device=input3d.device)
-        if input3d.requires_grad or weight.requires_grad or bias.requires_grad:
-            norm = torch.empty_like(
-                input3d, dtype=torch.float, device=input3d.device)
-            std = torch.empty(
-                input3d.size(1), dtype=torch.float, device=input3d.device)
-        else:
-            norm = torch.empty(0, dtype=torch.float, device=input3d.device)
-            std = torch.empty(0, dtype=torch.float, device=input3d.device)
+        norm = torch.empty_like(
+            input3d, dtype=torch.float, device=input3d.device)
+        std = torch.empty(
+            input3d.size(1), dtype=torch.float, device=input3d.device)

         ext_module.sync_bn_forward_mean(input3d, mean)
         if self.group_size > 1:
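The dropped branch is the parrots half of the fix: previously `norm` and `std` were zero-size placeholders whenever no input required gradients, and that conditional path misbehaved under the parrots backend (my reading of the change; the commit title only says "parrots bug"). They are now always allocated at full size, at the cost of one float buffer shaped like the input plus one per-channel vector:

# Buffer shapes now allocated unconditionally, mirroring the diff above.
import torch

input3d = torch.randn(4, 8, 32)  # (N, C, L), as in SyncBatchNormFunction
norm = torch.empty_like(input3d, dtype=torch.float)    # elementwise buffer
std = torch.empty(input3d.size(1), dtype=torch.float)  # one entry per channel
assert norm.shape == input3d.shape and std.numel() == input3d.size(1)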
@@ -73,10 +69,10 @@ class SyncBatchNormFunction(Function):
             input3d,
             mean,
             var,
-            running_mean,
-            running_var,
             weight,
             bias,
+            running_mean,
+            running_var,
             norm,
             std,
             output3d,
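With this reorder, the positional tuple passed to `ext_module.sync_bn_forward_output` lines up one-for-one with the `py::arg` list above. Callers normally reach this path through mmcv's `SyncBatchNorm` module rather than the raw extension, roughly as follows (sketch; assumes an initialized distributed process group, e.g. via the test helper below):

import torch
from mmcv.ops import SyncBatchNorm

# Requires torch.distributed to be initialized; the module then routes
# through sync_bn_forward_* with the corrected argument order.
bn = SyncBatchNorm(num_features=8).cuda()
x = torch.randn(4, 8, 16, 16, device='cuda')
y = bn(x)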
@@ -21,13 +21,13 @@ class TestSyncBN(object):
         node_list = str(os.environ['SLURM_NODELIST'])
         node_parts = re.findall('[0-9]+', node_list)
-        host_ip = '{}.{}.{}.{}'.format(node_parts[1], node_parts[2],
-                                       node_parts[3], node_parts[4])
-        port = '12341'
-        init_method = 'tcp://{}:{}'.format(host_ip, port)
+        os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
+                                     f'.{node_parts[3]}.{node_parts[4]}')
+        os.environ['MASTER_PORT'] = '12341'
+        os.environ['WORLD_SIZE'] = str(world_size)
+        os.environ['RANK'] = str(rank)

-        dist.init_process_group(
-            'nccl', init_method=init_method, world_size=world_size, rank=rank)
+        dist.init_process_group('nccl')
         torch.cuda.set_device(local_rank)

     def _test_syncbn_train(self, size=1, half=False):
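The test setup now exports `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, and `RANK` and calls `dist.init_process_group('nccl')` with its default `env://` init method, instead of hand-assembling a `tcp://` init string. Outside SLURM, the same minimal pattern looks like this (single-node sketch; the address is illustrative):

import os
import torch
import torch.distributed as dist

def init_dist(rank: int, world_size: int, local_rank: int) -> None:
    # The default env:// init method reads these four variables.
    os.environ['MASTER_ADDR'] = '127.0.0.1'  # illustrative single-node address
    os.environ['MASTER_PORT'] = '12341'
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    dist.init_process_group('nccl')  # same call as the updated test
    torch.cuda.set_device(local_rank)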