"test/ut/vscode:/vscode.git/clone" did not exist on "84b9c9b24e41e7b2ff51e2f73dd0dede587fec13"
Unverified Commit 5a34d8d9 authored by e's avatar e Committed by GitHub
Browse files

move device statements outside if statements (#11292)

parent d9c62047
...@@ -394,13 +394,14 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -394,13 +394,14 @@ class CTRLModel(CTRLPreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if past_key_values is None: if past_key_values is None:
past_length = 0 past_length = 0
past_key_values = tuple([None] * len(self.h)) past_key_values = tuple([None] * len(self.h))
else: else:
past_length = past_key_values[0][0].size(-2) past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
...@@ -438,11 +439,11 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -438,11 +439,11 @@ class CTRLModel(CTRLPreTrainedModel):
inputs_embeds = self.w(input_ids) inputs_embeds = self.w(input_ids)
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_shape[-1] seq_len = input_shape[-1]
mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device) mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(device)
inputs_embeds *= np.sqrt(self.d_model_size) inputs_embeds *= np.sqrt(self.d_model_size)
pos_embeds = self.pos_encoding[position_ids, :].to(inputs_embeds.device) pos_embeds = self.pos_encoding[position_ids, :].to(device)
hidden_states = inputs_embeds + pos_embeds + token_type_embeds hidden_states = inputs_embeds + pos_embeds + token_type_embeds
......
...@@ -675,6 +675,8 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -675,6 +675,8 @@ class GPT2Model(GPT2PreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None: if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1]) token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None: if position_ids is not None:
...@@ -686,7 +688,6 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -686,7 +688,6 @@ class GPT2Model(GPT2PreTrainedModel):
else: else:
past_length = past_key_values[0][0].size(-2) past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
......
...@@ -755,6 +755,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel): ...@@ -755,6 +755,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None: if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1]) token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None: if position_ids is not None:
...@@ -766,7 +768,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel): ...@@ -766,7 +768,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
else: else:
past_length = past_key_values[0][0].size(-2) past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment