Unverified Commit 4baa34c1 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Flava`] Fix flava `torch.distributed.nn.functional import all_gather` issue (#23108)

* fix flava `torch.distributed.nn.functional import all_gather` issue

* more comments
parent c6c66584
......@@ -1693,8 +1693,10 @@ class FlavaGlobalContrastiveHead(nn.Module):
world_size = torch.distributed.get_world_size()
if self.global_backprop_contrastive:
image_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(image_embeddings)
text_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(text_embeddings)
# `torch.distributed.nn.functional.all_gather` does backprop on all active workers
# whereas `torch.distributed.all_gather` does only backpropagates on the current worker.
image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
else:
image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment