"...html/git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "ceb47f1d7678d2f155144abd3e0eefb24684d35e"
Unverified Commit 80479eed authored by mcarilli, committed by GitHub

More stringent check for parameter changes to trigger refresh of distributed (#20)

* More stringent check for distributed refresh
parent a7319cee
@@ -428,9 +428,8 @@ class FP16_Optimizer(object):
         while(self.overflow):
             scale = self.loss_scaler.loss_scale
             self._update_scale(self.overflow)
-            if self.overflow:
-                print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
-                      "reducing to {}".format(scale, self.loss_scale))
+            print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
+                  "reducing to {}".format(scale, self.loss_scale))
             temp_loss = closure()
         return temp_loss
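The first hunk drops the inner `if self.overflow:` guard, presumably because the enclosing `while(self.overflow)` loop only executes while an overflow is pending, so the warning should print on every retry iteration. A minimal, self-contained sketch of this dynamic loss-scaling retry pattern (the `step_with_closure` name, the closure signature, and the halving update are illustrative assumptions, not apex's actual API):

```python
# Illustrative sketch of a dynamic loss-scaling retry loop -- not apex's API.

def step_with_closure(closure, loss_scale=2.0 ** 16, min_scale=1.0):
    """Re-run the closure, halving the loss scale after each gradient overflow."""
    overflow = True  # step() is entered after an overflow was already detected
    loss = None
    while overflow:
        attempted = loss_scale
        loss_scale = max(loss_scale / 2.0, min_scale)  # stand-in for _update_scale()
        print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
              "reducing to {}".format(attempted, loss_scale))
        loss, overflow = closure(loss_scale)  # closure re-runs forward/backward and reports overflow
    return loss


# Toy closure: the first two re-runs overflow, the third succeeds.
runs = []
def closure(scale):
    runs.append(scale)
    return 0.123, len(runs) < 3

print(step_with_closure(closure))  # prints three OVERFLOW messages, then 0.123
```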
@@ -198,12 +198,13 @@ class DistributedDataParallel(Module):
         #Force needs_refresh to True if there are shared params
         #this will force it to always, only call flush_buckets which is safe
         #for shared parameters in the model.
-        if not self.param_refs or self.shared_param:
+        #Parentheses are not necessary for correct order of operations, but make the intent clearer.
+        if (not self.param_refs) or self.shared_param:
             self.needs_refresh = True
         else:
-            self.needs_refresh = any(
-                [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
-            )
+            self.needs_refresh = (
+                (len(param_list) != len(self.param_refs)) or any(
+                [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]))
         if self.needs_refresh:
             self.record = []
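The second hunk is the "more stringent check" from the commit title: `zip()` truncates to the shorter of its two arguments, so the old element-wise identity test could return False even when parameters had been added to or removed from the model since the last pass. Comparing lengths first closes that gap. A standalone illustration, with plain Python objects standing in for parameter tensors:

```python
# Stand-ins for parameter tensors: distinct objects compared by identity.
p1, p2, p3, p4 = object(), object(), object(), object()
old_refs = [p1, p2, p3]        # parameters recorded on the previous iteration
new_list = [p1, p2, p3, p4]    # the model gained a parameter since then

# Old check: zip() stops at the shorter list, so p4 is never compared.
zip_only = any(param1 is not param2 for param1, param2 in zip(new_list, old_refs))
print(zip_only)        # False -- the appended parameter goes unnoticed

# New, more stringent check: compare lengths first.
needs_refresh = (len(new_list) != len(old_refs)) or zip_only
print(needs_refresh)   # True -- the length mismatch triggers a refresh
```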