Unverified Commit 44caf4f6 authored by layjain's avatar layjain Committed by GitHub
Browse files

Fixed num_channels!=3 normalization training (#20630)

* Fixed num_channels!=3 normalization training

* empty commit to trigger CI

* Empty-Commit for CircleCI

* Empty-Commit

* Empty Commit try-3: https://discuss.circleci.com/t/github-code-checkout-suddenly-failing/31558



* Empty commit to trigger CI
Co-authored-by: default avatarLay Jain <layjain@basil.csail.mit.edu>
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 865da84a
......@@ -819,11 +819,15 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
loss = None
with torch.no_grad():
# calculate the labels to be predicted
# first, unnormalize the frames
device = pixel_values.device
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None]
std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None]
frames = pixel_values * std + mean # in [0, 1]
if self.config.num_channels != 3:
# Can't unnormalize with default means/stds
frames = pixel_values
else:
# first, unnormalize the frames
device = pixel_values.device
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None]
std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None]
frames = pixel_values * std + mean # in [0, 1]
batch_size, time, num_channels, height, width = frames.shape
tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
......@@ -859,6 +863,10 @@ class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
tubelet_size * patch_size * patch_size * num_channels,
)
else:
if self.config.num_channels != 3:
raise ValueError(
"Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to False."
)
# step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
frames = frames.view(
batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment