Commit 47da14a0 authored by Michael Carilli

cosmetic

parent 8a32e428
@@ -231,7 +231,7 @@ class DistributedDataParallel(Module):
        self.module = module

        self._disable_allreduce = False

        if self._backend == self.backend_enum_holder.NCCL:
            for param in self.module.parameters():
                assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
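
The assert in this hunk is the guard users hit when wrapping a CPU model while the NCCL backend is selected. A minimal usage sketch of the expected call order (assumes a launcher such as torchrun sets LOCAL_RANK; the tiny Linear model is illustrative, not part of this commit):

```python
import os
import torch
from apex.parallel import DistributedDataParallel as DDP

# Assumption: the process was started by a launcher that sets LOCAL_RANK.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(4, 4).cuda()   # parameters must be on GPU *before* wrapping,
model = DDP(model)                     # otherwise the assert above fires
```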
@@ -277,9 +277,9 @@ class DistributedDataParallel(Module):
    def disable_allreduce(self):
        self._disable_allreduce = True

    # Broadcast rank 0's bucket structure across all processes, and have all processes
    # regenerate their bucket structures to match.
    def sync_bucket_structure(self):
        # Append leftover buckets
        for tmp_bucket in self.tmp_buckets:
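
The comment in this hunk states the contract of sync_bucket_structure: rank 0's bucket layout, built during the first backward pass, is the reference that every other rank adopts. A minimal sketch of that idea (the helper below is hypothetical, not the method in this file): flatten the nested index lists into one integer tensor and broadcast it from rank 0.

```python
import torch
import torch.distributed as dist

def broadcast_bucket_structure(active_i_buckets, device="cuda"):
    """Hypothetical helper: make every rank adopt rank 0's bucket layout.

    active_i_buckets is a list of lists of parameter indices, e.g.
    [[0, 1, 2], [3, 4]].  Rank 0 flattens it into
    [num_buckets, len(b0), ..., idx0, idx1, ...] and broadcasts it;
    other ranks rebuild the nested list from the received tensor.
    """
    if dist.get_rank() == 0:
        flat = [len(active_i_buckets)] + [len(b) for b in active_i_buckets]
        for b in active_i_buckets:
            flat += b
        info = torch.tensor(flat, dtype=torch.long, device=device)
        size = torch.tensor([len(flat)], dtype=torch.long, device=device)
    else:
        size = torch.empty(1, dtype=torch.long, device=device)

    dist.broadcast(size, 0)                 # everyone learns the payload length
    if dist.get_rank() != 0:
        info = torch.empty(size.item(), dtype=torch.long, device=device)
    dist.broadcast(info, 0)                 # everyone receives rank 0's layout

    info = info.tolist()
    num_buckets = info[0]
    sizes = info[1:1 + num_buckets]
    idxs = info[1 + num_buckets:]
    buckets, start = [], 0
    for s in sizes:
        buckets.append(idxs[start:start + s])
        start += s
    return buckets
```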
@@ -356,7 +356,6 @@ class DistributedDataParallel(Module):
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]

                    def allreduce_hook(*unused):
                        if self.prof:
                            torch.cuda.nvtx.range_push("allreduce_hook")
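
The grad_acc line relies on a standard PyTorch idiom: param.expand_as(param) yields a non-leaf view whose grad_fn's next_functions lead back to the parameter's AccumulateGrad node, and a hook registered on that node fires once the parameter's gradient is ready during a backward pass. A self-contained illustration of the same pattern, outside this class:

```python
import torch

param = torch.nn.Parameter(torch.randn(4))

# The expanded view is non-leaf, so it has a grad_fn; its first
# next_function is the AccumulateGrad node of the original parameter.
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

def hook(*unused):
    # Runs during backward, once this parameter's gradient has been accumulated.
    print("gradient ready for param of shape", tuple(param.shape))

grad_acc.register_hook(hook)
(param * 2).sum().backward()   # prints once
```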
@@ -371,8 +370,8 @@ class DistributedDataParallel(Module):
                                    # Float, half, and double tensors are grouped into buckets separately.
                                    current_type = self.param_type_to_tmp_i[param.type()]

                                    self.tmp_buckets[current_type].append(active_i)

                                    ship_tmp_bucket = False
                                    if self.custom_allreduce_triggers:
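
For context, param_type_to_tmp_i (defined in the constructor, outside this diff) maps a tensor's type string to the slot of the temporary bucket that collects gradients of that dtype, so half, float, and double gradients are never flattened into the same bucket. An illustrative version of that bookkeeping (the exact contents below are an assumption, not copied from this file):

```python
# Illustrative only: one temporary bucket slot per gradient dtype.
param_type_to_tmp_i = {"torch.cuda.HalfTensor":   0,
                       "torch.cuda.FloatTensor":  1,
                       "torch.cuda.DoubleTensor": 2}

tmp_buckets = [[], [], []]   # per-dtype lists of active parameter indices
tmp_numels  = [0, 0, 0]      # running element counts, checked against message_size
```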
@@ -389,20 +388,20 @@ class DistributedDataParallel(Module):
                                        self.active_i_buckets.append(self.tmp_buckets[current_type])
                                        self.tmp_buckets[current_type] = []
                                        self.tmp_numels[current_type] = 0

                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(allreduce_params)
                                    self.callback_queued = True
                            else:
                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                                    self.callback_queued = True

                                self.comm_ready_buckets(param)

                        if self.prof:
                            torch.cuda.nvtx.range_pop()

                    grad_acc.register_hook(allreduce_hook)
                    self.grad_accs.append(grad_acc)
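
Both branches of this hunk lean on Variable._execution_engine.queue_callback, an internal autograd hook point that runs a callable after the current backward pass finishes; it is how allreduce_params (the bucket-building first iteration) or overlapping_backward_epilogue (steady state) gets executed exactly once per backward. A standalone sketch of that mechanism, unrelated to this class:

```python
import torch
from torch.autograd import Variable

x = torch.randn(8, requires_grad=True)

def epilogue():
    # Runs after the ongoing backward pass has completed.
    print("backward done, grad norm =", x.grad.norm().item())

def tensor_hook(grad):
    # Queue the epilogue from *inside* the backward pass.
    Variable._execution_engine.queue_callback(epilogue)
    return grad

x.register_hook(tensor_hook)
x.sum().backward()   # the epilogue prints after all gradients are ready
```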