Unverified Commit 2400eb4c authored by Yih-Dar, committed by GitHub

Fix some CI torch device issues for PyTorch 1.13 (#19681)



* fix some device issues for pt 1.13

* Update src/transformers/models/ctrl/modeling_ctrl.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 2add2007
src/transformers/models/ctrl/modeling_ctrl.py
@@ -456,7 +456,9 @@ class CTRLModel(CTRLPreTrainedModel):
         inputs_embeds *= np.sqrt(self.d_model_size)
-        pos_embeds = self.pos_encoding[position_ids, :].to(device)
+        # `self.pos_encoding` won't be sent to the correct device along the model, so we do it manually.
+        self.pos_encoding = self.pos_encoding.to(device)
+        pos_embeds = self.pos_encoding[position_ids, :]
         hidden_states = inputs_embeds + pos_embeds + token_type_embeds
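
For context on the CTRL change: `pos_encoding` is assigned to the module as a plain tensor attribute rather than registered with `register_buffer`, so `nn.Module.to(device)` never moves it along with the parameters; the diff therefore relocates it by hand inside `forward`. A minimal sketch of the underlying behavior (the class names below are illustrative, not from this commit):

    import torch
    from torch import nn

    class PlainAttribute(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attribute: invisible to `nn.Module.to()`.
            self.pos_encoding = torch.randn(512, 64)

    class RegisteredBuffer(nn.Module):
        def __init__(self):
            super().__init__()
            # Registered buffer: moved together with the module's parameters.
            self.register_buffer("pos_encoding", torch.randn(512, 64))

    if torch.cuda.is_available():
        print(PlainAttribute().to("cuda").pos_encoding.device)   # cpu -- left behind
        print(RegisteredBuffer().to("cuda").pos_encoding.device) # cuda:0

The one-line `self.pos_encoding = self.pos_encoding.to(device)` above is the lighter-touch fix: the attribute stays a plain tensor but gets moved to the right device on the first forward pass.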
...@@ -136,9 +136,10 @@ class ViltEmbeddings(nn.Module): ...@@ -136,9 +136,10 @@ class ViltEmbeddings(nn.Module):
pos_embed = pos_embed.flatten(2).transpose(1, 2) pos_embed = pos_embed.flatten(2).transpose(1, 2)
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
# Set `device` here, otherwise `patch_index` will always be on `CPU` and will fail near the end for torch>=1.13
patch_index = torch.stack( patch_index = torch.stack(
torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1 torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
) ).to(device=x_mask.device)
patch_index = patch_index[None, None, :, :, :] patch_index = patch_index[None, None, :, :, :]
patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1) patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
patch_index = patch_index.flatten(1, 3) patch_index = patch_index.flatten(1, 3)
@@ -177,6 +178,7 @@ class ViltEmbeddings(nn.Module):
         select = torch.cat(select, dim=0)
         x = x[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
         x_mask = x_mask[select[:, 0], select[:, 1]].view(batch_size, -1)
+        # `patch_index` should be on the same device as `select` (for torch>=1.13), which is ensured at definition time.
         patch_index = patch_index[select[:, 0], select[:, 1]].view(batch_size, -1, 2)
         pos_embed = pos_embed[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
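
For context on the ViLT change: `torch.arange` and `torch.stack` produce CPU tensors by default, so without the `.to(device=x_mask.device)` the later indexing `patch_index[select[:, 0], select[:, 1]]` mixes a CPU tensor with GPU indices, which torch>=1.13 rejects with a device-mismatch error. A rough standalone reproduction (shapes and names are made up for illustration; assumes a CUDA device is available):

    import torch

    if torch.cuda.is_available():
        patch_index = torch.zeros(12, 19, 2)                   # built on CPU, like the old code
        select = torch.randint(0, 12, (30, 2), device="cuda")  # indices live on the GPU

        try:
            patch_index[select[:, 0], select[:, 1]]            # CPU data, CUDA indices
        except RuntimeError as err:
            print(err)  # torch>=1.13 raises a device-mismatch error here

        # The fix from the diff: put `patch_index` on the indices' device up front.
        patch_index = patch_index.to(select.device)
        print(patch_index[select[:, 0], select[:, 1]].shape)   # torch.Size([30, 2]) -- works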