Commit 2f3a4210 authored by Julien Chaumond

Fix other PyTorch models

parent d5319793
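
Both hunks below apply the same fix: `attention_mask` and `token_type_ids` were previously created on PyTorch's default device (the CPU), so running a model whose inputs live on a GPU raised a device-mismatch error. Deriving `device` from whichever of `input_ids` or `inputs_embeds` was passed keeps the freshly created tensors on the same device as the inputs. A minimal standalone sketch of the failure mode and the fix (the token ids are illustrative, and the script degrades to CPU when no GPU is available):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stand-in for the model inputs; in the real models these arrive on
    # whatever device the caller placed them on.
    input_ids = torch.tensor([[101, 2023, 102]], device=device)
    input_shape = input_ids.size()

    # Before the fix: torch.ones() allocates on the CPU by default, so on a
    # GPU run the mask and the embeddings end up on different devices and
    # the first op that combines them raises a RuntimeError.
    mask_before = torch.ones(input_shape)

    # After the fix: the mask is allocated on the same device as the inputs.
    mask_after = torch.ones(input_shape, device=input_ids.device)

    print(mask_before.device)  # always cpu
    print(mask_after.device)   # cuda:0 on a GPU machine, cpu otherwise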
@@ -309,10 +309,12 @@ class XxxModel(XxxPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)
+            attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
...
@@ -450,8 +450,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)  # (bs, seq_length)
+            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
...