Unverified Commit a6e0d5a2 authored by Iker García-Ferrero's avatar Iker García-Ferrero Committed by GitHub
Browse files

Fix VideoMAEforPretrained dtype error (#27296)

* Fix dtype error

* Fix mean and std dtype

* make style
parent e9dbd392
......@@ -848,8 +848,9 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
else:
# first, unnormalize the frames
device = pixel_values.device
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None]
std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None]
dtype = pixel_values.dtype
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
frames = pixel_values * std + mean # in [0, 1]
batch_size, time, num_channels, height, width = frames.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment