Unverified Commit 4ef930c1 authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Should pass stricter stride/size checks in pytorch (#942)

parent 5d9b5cbc
...@@ -33,11 +33,11 @@ class SyncBatchnormFunction(Function): ...@@ -33,11 +33,11 @@ class SyncBatchnormFunction(Function):
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device) mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device) var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
count_all = torch.cuda.IntTensor(world_size, device=device) count_all = torch.cuda.IntTensor(world_size, device=device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)] mean_l = [mean_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)] var_l = [var_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
count_l = [count_all.narrow(0, i, 1) for i in range(world_size)] count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group) torch.distributed.all_gather(mean_l, mean.view(-1), process_group)
torch.distributed.all_gather(var_l, var_biased, process_group) torch.distributed.all_gather(var_l, var_biased.view(-1), process_group)
torch.distributed.all_gather( torch.distributed.all_gather(
count_l, count_l,
torch.cuda.IntTensor([count], device=device), torch.cuda.IntTensor([count], device=device),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment