Commit d5319793 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fix BERT

parent 27e015bd
......@@ -170,7 +170,7 @@ class BertEmbeddings(nn.Module):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long)
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
......@@ -655,11 +655,11 @@ class BertModel(BertPreTrainedModel):
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape)
attention_mask = torch.ones(input_shape, device=device)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(input_shape)
encoder_attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long)
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment