Unverified Commit 0578a913 authored by guhur's avatar guhur Committed by GitHub
Browse files

fix nn.DataParallel compatibility with PyTorch 1.5 (#7671)

The same type of errors as in https://github.com/huggingface/transformers/pull/4300
parent 297233fa
......@@ -964,7 +964,7 @@ class LxmertModel(LxmertPreTrainedModel):
# Process the visual attention mask
if visual_attention_mask is not None:
extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=next(self.parameters()).dtype)
extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
else:
extended_visual_attention_mask = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment