Unverified Commit 5a34d8d9 authored by e's avatar e Committed by GitHub
Browse files

move device statements outside if statements (#11292)

parent d9c62047
......@@ -394,13 +394,14 @@ class CTRLModel(CTRLPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
......@@ -438,11 +439,11 @@ class CTRLModel(CTRLPreTrainedModel):
inputs_embeds = self.w(input_ids)
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_shape[-1]
mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device)
mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(device)
inputs_embeds *= np.sqrt(self.d_model_size)
pos_embeds = self.pos_encoding[position_ids, :].to(inputs_embeds.device)
pos_embeds = self.pos_encoding[position_ids, :].to(device)
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
......
......@@ -675,6 +675,8 @@ class GPT2Model(GPT2PreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
......@@ -686,7 +688,6 @@ class GPT2Model(GPT2PreTrainedModel):
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
......
......@@ -755,6 +755,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
......@@ -766,7 +768,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment