Unverified Commit a00952bc authored by jjsjann123's avatar jjsjann123 Committed by GitHub
Browse files

Merge pull request #391 from NVIDIA/persistent_sync_bn_group8_fix

Fixing rank mapping for bn_group size == 8
parents 1483f22d 89ae9e54
......@@ -184,8 +184,8 @@ class BatchNorm2d_NHWC(_BatchNorm):
self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)
if bn_group>4:
self.pair_handle3 = handles_l[local_rank ^ 3].cpu().contiguous()
pair_offset3 = offsets_l[local_rank ^ 3].cpu()
self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()
pair_offset3 = offsets_l[local_rank ^ 4].cpu()
self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)
#FIXME: get magic value into C code and eliminate from here
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment