Unverified Commit 5b40a37b authored by Sayak Paul, committed by GitHub

Add TF ViT MAE (#16255)



* ported TFViTMAEIntermediate and TFViTMAEOutput.

* added TFViTMAEModel and TFViTMAEDecoder.

* feat: added a noise argument in the implementation for reproducibility.

* feat: vit mae models with an additional noise argument for reproducibility.
Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 7a9ef818
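The commit message above highlights the new `noise` argument added for reproducibility. A minimal sketch of how it can be used on the PyTorch side (assuming the `facebook/vit-mae-base` checkpoint and the default configuration, both of which also appear in the tests further down):

```python
import numpy as np
import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

config = ViTMAEConfig()
num_patches = (config.image_size // config.patch_size) ** 2  # 196 with the default config

# a fixed noise tensor makes the random masking deterministic;
# the same array can later be fed to the TF model to get identical masks
noise = torch.from_numpy(np.random.uniform(size=(1, num_patches)))

model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
pixel_values = torch.rand(1, 3, config.image_size, config.image_size)

with torch.no_grad():
    outputs = model(pixel_values, noise=noise)

print(outputs.logits.shape)  # torch.Size([1, 196, 768])
```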
@@ -260,7 +260,7 @@ Flax), PyTorch, and/or TensorFlow.
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
- | ViTMAE | ❌ | ❌ | ✅ | ❌ | ❌ |
+ | ViTMAE | ❌ | ❌ | ✅ | ✅ | ❌ |
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
......
@@ -41,13 +41,16 @@ fine-tuning, one can directly plug in the weights into a [`ViTForImageClassifica
- Note that the encoder of MAE is only used to encode the visual patches. The encoded patches are then concatenated with mask tokens, which the decoder (which also
consists of Transformer blocks) takes as input. Each mask token is a shared, learned vector that indicates the presence of a missing patch to be predicted. Fixed
sin/cos position embeddings are added both to the input of the encoder and the decoder.
- For a visual understanding of how MAEs work, you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/).
<img src="https://user-images.githubusercontent.com/11435359/146857310-f258c86c-fde6-48e8-9cee-badd2b21bd2c.png"
alt="drawing" width="600"/>
<small> MAE architecture. Taken from the <a href="https://arxiv.org/abs/2111.06377">original paper.</a> </small>
- This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/mae).
+ This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version of the model was contributed by [sayakpaul](https://github.com/sayakpaul) and
+ [ariG23498](https://github.com/ariG23498) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/mae).
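For readers of the updated page, a short usage sketch of the new TensorFlow base model (the `from_pt=True` conversion and the use of the plain `ViTFeatureExtractor` are assumptions not shown in this diff; the image is the COCO sample used elsewhere in the library docs):

```python
import requests
from PIL import Image
from transformers import TFViTMAEModel, ViTFeatureExtractor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base", from_pt=True)

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)

# only the patches that survive masking are encoded, plus the [CLS] token,
# so with the default 0.75 mask ratio the sequence length is 1 + 49 = 50
print(outputs.last_hidden_state.shape)  # (1, 50, 768)
```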
## ViTMAEConfig
@@ -64,3 +67,15 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
[[autodoc]] transformers.ViTMAEForPreTraining
- forward
## TFViTMAEModel
[[autodoc]] TFViTMAEModel
- call
## TFViTMAEForPreTraining
[[autodoc]] transformers.TFViTMAEForPreTraining
- call
@@ -2135,6 +2135,13 @@ if is_tf_available():
"TFViTPreTrainedModel",
]
)
_import_structure["models.vit_mae"].extend(
[
"TFViTMAEForPreTraining",
"TFViTMAEModel",
"TFViTMAEPreTrainedModel",
]
)
_import_structure["models.wav2vec2"].extend( _import_structure["models.wav2vec2"].extend(
[ [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4170,6 +4177,7 @@ if TYPE_CHECKING:
)
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
......
@@ -70,6 +70,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("vit", "TFViTModel"),
("vit_mae", "TFViTMAEModel"),
("wav2vec2", "TFWav2Vec2Model"), ("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"), ("hubert", "TFHubertModel"),
] ]
@@ -100,6 +101,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("tapas", "TFTapasForMaskedLM"),
("funnel", "TFFunnelForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
("vit_mae", "TFViTMAEForPreTraining"),
]
)
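With the `vit_mae` entries registered in these mappings, the TF auto classes should resolve ViT MAE checkpoints to the new classes. A hedged sketch of the lookup (again assuming `from_pt=True` to convert the PyTorch weights on the fly):

```python
from transformers import TFAutoModel, TFAutoModelForPreTraining

# config-based lookup goes through TF_MODEL_MAPPING_NAMES
model = TFAutoModel.from_pretrained("facebook/vit-mae-base", from_pt=True)
print(type(model).__name__)  # TFViTMAEModel

# and through TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES for the pre-training head
pretrainer = TFAutoModelForPreTraining.from_pretrained("facebook/vit-mae-base", from_pt=True)
print(type(pretrainer).__name__)  # TFViTMAEForPreTraining
```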
......
@@ -33,6 +33,12 @@ if is_torch_available():
"ViTMAEPreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_vit_mae"] = [
"TFViTMAEForPreTraining",
"TFViTMAEModel",
"TFViTMAEPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
@@ -46,6 +52,9 @@ if TYPE_CHECKING:
ViTMAEPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
else:
import sys
......
This diff is collapsed.
@@ -240,18 +240,21 @@ class ViTMAEEmbeddings(nn.Module):
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
- def random_masking(self, sequence):
+ def random_masking(self, sequence, noise=None):
""" """
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
noise. noise.
Args: Args:
sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`) sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
mainly used for testing purposes to control randomness and maintain the reproducibility
""" """
batch_size, seq_length, dim = sequence.shape
len_keep = int(seq_length * (1 - self.config.mask_ratio))
- noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]
+ if noise is None:
+     noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
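The docstring above describes per-sample masking via an argsort over random noise. A standalone sketch of the same idea, outside the model code and with made-up shapes, may help make the mechanics concrete:

```python
import torch

batch_size, seq_length, dim, mask_ratio = 2, 8, 4, 0.75
sequence = torch.randn(batch_size, seq_length, dim)
len_keep = int(seq_length * (1 - mask_ratio))

noise = torch.rand(batch_size, seq_length)        # one noise value per patch
ids_shuffle = torch.argsort(noise, dim=1)         # ascending: low noise -> kept, high noise -> removed
ids_restore = torch.argsort(ids_shuffle, dim=1)   # inverse permutation, used later to unshuffle

# keep the first `len_keep` patches of the shuffled order
ids_keep = ids_shuffle[:, :len_keep]
sequence_unmasked = torch.gather(sequence, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))

# binary mask in the original patch order: 0 = keep, 1 = masked
mask = torch.ones(batch_size, seq_length)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)
```

Fixing `noise` instead of drawing it with `torch.rand` is exactly what the new argument enables: identical noise yields identical `ids_shuffle`, and therefore identical masks, across runs and frameworks.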
@@ -269,7 +272,7 @@ class ViTMAEEmbeddings(nn.Module):
return sequence_masked, mask, ids_restore
- def forward(self, pixel_values):
+ def forward(self, pixel_values, noise=None):
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
@@ -277,7 +280,7 @@ class ViTMAEEmbeddings(nn.Module):
embeddings = embeddings + self.position_embeddings[:, 1:, :]
# masking: length -> length * config.mask_ratio
- embeddings, mask, ids_restore = self.random_masking(embeddings)
+ embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
# append cls token
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
@@ -668,6 +671,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
def forward(
self,
pixel_values=None,
noise=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
@@ -709,7 +713,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
- embedding_output, mask, ids_restore = self.embeddings(pixel_values)
+ embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
encoder_outputs = self.encoder(
embedding_output,
@@ -910,6 +914,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
def forward(
self,
pixel_values=None,
noise=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
@@ -941,6 +946,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
outputs = self.vit(
pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......
@@ -1987,6 +1987,27 @@ class TFViTPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFViTMAEForPreTraining(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTMAEModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTMAEPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
This diff is collapsed.
@@ -17,13 +17,14 @@
import inspect
import math
import os
import tempfile
import unittest
import numpy as np
from transformers import ViTMAEConfig
- from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+ from transformers.testing_utils import is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ..test_configuration_common import ConfigTester
@@ -139,11 +140,7 @@ class ViTMAEModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
- (
-     config,
-     pixel_values,
-     labels,
- ) = config_and_inputs
+ config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values} inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict return config, inputs_dict
@@ -322,6 +319,153 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class)
# overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import numpy as np
import tensorflow as tf
import transformers
# make masks reproducible
np.random.seed(2)
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
pt_noise = torch.from_numpy(noise).to(device=torch_device)
tf_noise = tf.constant(noise)
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
tf_inputs_dict = {}
for key, tensor in pt_inputs_dict.items():
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
return tf_inputs_dict
def check_outputs(tf_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
`TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
debugging easier and faster.
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
"""
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
if type(tf_outputs) in [tuple, list]:
self.assertEqual(type(tf_outputs), type(pt_outputs))
self.assertEqual(len(tf_outputs), len(pt_outputs))
if type(names) == tuple:
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
check_outputs(tf_output, pt_output, model_class, names=name)
elif type(names) == str:
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(tf_outputs, tf.Tensor):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
tf_outputs = tf_outputs.numpy()
if isinstance(tf_outputs, np.float32):
tf_outputs = np.array(tf_outputs, dtype=np.float32)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
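# positions that are NaN in either framework's output are zeroed in both tensors,
# so the max-difference check below only compares finite values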
tf_nans = np.isnan(tf_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[tf_nans] = 0
tf_outputs[tf_nans] = 0
pt_outputs[pt_nans] = 0
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)
else:
raise ValueError(
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
)
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict):
# we are not preparing a model with labels because of how
# the ViT MAE model is formulated
# send pytorch model to the correct device
pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
# send pytorch inputs to the correct device
pt_inputs_dict = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
}
# Original test: check without `labels`
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(tf_keys, pt_keys)
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
tf_model_class = getattr(transformers, tf_model_class_name)
tf_model = tf_model_class(config)
pt_model = model_class(config)
# make sure we only forward tf inputs that actually exist in the function args
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
# remove all head masks
tf_input_keys.discard("head_mask")
tf_input_keys.discard("cross_attn_head_mask")
tf_input_keys.discard("decoder_head_mask")
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
# Check we can load pt model in tf and vice-versa with model => model functions
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = pt_model.to(torch_device)
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -400,11 +544,8 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_for_pretraining(self):
- # make random mask reproducible
- # note that the same seed on CPU and on GPU doesn't mean they spew the same random number sequences,
- # as they both have fairly different PRNGs (for efficiency reasons).
- # source: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735
- torch.manual_seed(2)
+ # make random mask reproducible across the PT and TF model
+ np.random.seed(2)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
@@ -412,22 +553,22 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# prepare a noise vector that will be also used for testing the TF model
# (this way we can ensure that the PT and TF models operate on the same inputs)
vit_mae_config = ViTMAEConfig()
num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
noise = np.random.uniform(size=(1, num_patches))
# forward pass
with torch.no_grad():
- outputs = model(**inputs)
+ outputs = model(**inputs, noise=torch.from_numpy(noise))
# verify the logits
expected_shape = torch.Size((1, 196, 768))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice_cpu = torch.tensor(
-     [[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]]
- )
- expected_slice_gpu = torch.tensor(
-     [[0.8948, -1.0680, 0.0030], [0.9758, -1.1181, -0.0290], [1.0602, -1.1522, -0.0528]]
- )
- # set expected slice depending on device
- expected_slice = expected_slice_cpu if torch_device == "cpu" else expected_slice_gpu
+ expected_slice = torch.tensor(
+     [[-0.0548, -1.7023, -0.9325], [0.3721, -0.5670, -0.2233], [0.8235, -1.3878, -0.3524]]
+ )
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
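Because the test now seeds NumPy and builds the noise array outside the model, the TF port can be exercised against the same expectations. A hedged sketch of the TF-side counterpart (the `from_pt=True` conversion, the plain `ViTFeatureExtractor`, and the COCO sample image are assumptions not shown in this diff):

```python
import numpy as np
import requests
import tensorflow as tf
from PIL import Image
from transformers import TFViTMAEForPreTraining, ViTFeatureExtractor, ViTMAEConfig

np.random.seed(2)  # same seed as the PyTorch test above

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base", from_pt=True)

inputs = feature_extractor(images=image, return_tensors="tf")

# build the noise exactly as the PyTorch test does, so both frameworks mask the same patches
vit_mae_config = ViTMAEConfig()
num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
noise = np.random.uniform(size=(1, num_patches))

outputs = model(**inputs, noise=tf.constant(noise))
assert outputs.logits.shape == (1, 196, 768)

# if the port is numerically equivalent, the slice checked in the PyTorch test
# should come out the same within the test's tolerance
expected_slice = np.array(
    [[-0.0548, -1.7023, -0.9325], [0.3721, -0.5670, -0.2233], [0.8235, -1.3878, -0.3524]]
)
np.testing.assert_allclose(outputs.logits.numpy()[0, :3, :3], expected_slice, atol=1e-4)
```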