Commit d5319793 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fix BERT

parent 27e015bd
@@ -170,7 +170,7 @@ class BertEmbeddings(nn.Module):
         position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
         position_ids = position_ids.unsqueeze(0).expand(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -655,11 +655,11 @@ class BertModel(BertPreTrainedModel):
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)
+            attention_mask = torch.ones(input_shape, device=device)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(input_shape)
+            encoder_attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment