"...resnet50_tensorflow.git" did not exist on "283a0015c4491a2d6c8e325115ff3dcb01ee49fb"
Commit 4d181999 authored by Rémi Louf's avatar Rémi Louf Committed by Julien Chaumond
Browse files

cast bool tensor to long for pytorch < 1.3

parent 9f75565e
...@@ -675,6 +675,7 @@ class BertModel(BertPreTrainedModel): ...@@ -675,6 +675,7 @@ class BertModel(BertPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device) seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else: else:
extended_attention_mask = attention_mask[:, None, None, :] extended_attention_mask = attention_mask[:, None, None, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment