"...text-generation-inference.git" did not exist on "10d9083b2d5c47e73d9acb4b6cf64b9fcad6c934"
Unverified Commit 4ef930c1 authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Should pass stricter stride/size checks in pytorch (#942)

parent 5d9b5cbc
......@@ -33,11 +33,11 @@ class SyncBatchnormFunction(Function):
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
count_all = torch.cuda.IntTensor(world_size, device=device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
mean_l = [mean_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group)
torch.distributed.all_gather(var_l, var_biased, process_group)
torch.distributed.all_gather(mean_l, mean.view(-1), process_group)
torch.distributed.all_gather(var_l, var_biased.view(-1), process_group)
torch.distributed.all_gather(
count_l,
torch.cuda.IntTensor([count], device=device),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment