"docs/source/vscode:/vscode.git/clone" did not exist on "057e1d74733f52817dc05b673a340b4e3ebea08c"
Unverified Commit fa6d12f7 authored by Lucas Thompson's avatar Lucas Thompson Committed by GitHub
Browse files

Allow training dinov2 with different dtypes such as bf16 (#28504)

I want to train dinov2 with bf16, but I get the following error at https://github.com/huggingface/transformers/blob/bc72b4e2cdcbc80d5f56731f35dbc9c18b4c8de6/src/transformers/models/dinov2/modeling_dinov2.py#L635:

```
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
```

Because the input dtype is torch.float32, the parameter dtype currently has to be torch.float32 as well, which prevents training in bf16.

@LZHgrla and I checked the code of the CLIP vision encoder and found that it performs an automatic dtype conversion (https://github.com/huggingface/transformers/blob/bc72b4e2cdcbc80d5f56731f35dbc9c18b4c8de6/src/transformers/models/clip/modeling_clip.py#L181-L182).

So I added a similar automatic dtype conversion to modeling_dinov2.py.
parent 2c1eebc1
......@@ -103,12 +103,13 @@ class Dinov2Embeddings(nn.Module):
height, width = height + 0.1, width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
target_dtype = patch_pos_embed.dtype
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
patch_pos_embed.to(dtype=torch.float32),
scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))),
mode="bicubic",
align_corners=False,
)
).to(dtype=target_dtype)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
......@@ -116,7 +117,8 @@ class Dinov2Embeddings(nn.Module):
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
target_dtype = self.patch_embeddings.projection.weight.dtype
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
if bool_masked_pos is not None:
embeddings = torch.where(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment