Commit be4c41c2 authored by Thor Johnsen

Bug fix

parent db8fb976
@@ -338,7 +338,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     @property
     def L2_grad_norm(self):
-        return self._L2_grad_norm
+        if self._compute_L2_grad_norm:
+            for i, blk_st in enumerate(self._blk_st):
+                torch.cuda.current_stream().wait_stream(blk_st)
+            return self._L2_grad_norm
+        else:
+            return None
     # Distributed weight update algorithm:
     # Model parameters are kept as-is.
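The change makes the L2_grad_norm property synchronize the current CUDA stream with each per-block stream in self._blk_st before returning the cached norm, so a caller can no longer read a value whose computation is still in flight on a side stream; when norm computation is disabled it now returns None instead of a stale attribute. Below is a minimal sketch of the same producer/consumer stream pattern, assuming a CUDA device; side_stream and grad_norm_sq are illustrative names, not identifiers from apex:

import torch

device = torch.device("cuda")
grads = torch.randn(1 << 20, device=device)

side_stream = torch.cuda.Stream()
# Make the side stream wait for any pending work that produced `grads`.
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    # The squared norm is launched asynchronously on the side stream,
    # much as the optimizer computes its norm on per-block streams.
    grad_norm_sq = torch.sum(grads * grads)

# The consumer must wait on the producer stream before reading the result;
# skipping this wait is the kind of race the commit guards against.
torch.cuda.current_stream().wait_stream(side_stream)
print(grad_norm_sq.sqrt().item())

The wait_stream call only orders the two streams on the device; it does not block the host, which is why the property stays cheap when the side-stream work has already finished.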