Unverified Commit 4cbd876e authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[`Vilt`] align input and model dtype in the ViltPatchEmbeddings forward pass (#28633)

align dtype
parent 24f1a00e
......@@ -317,7 +317,8 @@ class ViltPatchEmbeddings(nn.Module):
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
x = self.projection(pixel_values)
target_dtype = self.projection.weight.dtype
x = self.projection(pixel_values.to(dtype=target_dtype))
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment