Unverified Commit 4cbd876e authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[`Vilt`] align input and model dtype in the ViltPatchEmbeddings forward pass (#28633)

align dtype
parent 24f1a00e
...@@ -317,7 +317,8 @@ class ViltPatchEmbeddings(nn.Module): ...@@ -317,7 +317,8 @@ class ViltPatchEmbeddings(nn.Module):
raise ValueError( raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration." "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
) )
x = self.projection(pixel_values) target_dtype = self.projection.weight.dtype
x = self.projection(pixel_values.to(dtype=target_dtype))
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment