Commit 4b984a7c authored by CapMocha's avatar CapMocha Committed by Kai Chen
Browse files

Fix a bug for distributed training (#1985)

Fix a bug for distributed training on the Windows platform
parent e032ebb7
...@@ -73,7 +73,7 @@ def parse_losses(losses): ...@@ -73,7 +73,7 @@ def parse_losses(losses):
log_vars['loss'] = loss log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items(): for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training # reduce loss when distributed training
if dist.is_initialized(): if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone() loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size())) dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item() log_vars[loss_name] = loss_value.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment