Unverified Commit 57882177 authored by NielsRogge, committed by GitHub

Add SimMIM (#15586)



* Add first draft

* Make model importable

* Make SwinForMaskedImageModeling importable

* Fix imports

* Add missing inits

* Add support for Swin

* Fix bug

* Fix bug

* Fix another bug

* Fix Swin MIM implementation

* Fix default encoder stride

* Fix Swin

* Add print statements for debugging

* Add image_size data argument

* Fix Swin

* Fix image_size

* Add print statements for debugging

* Fix print statement

* Remove print statements

* Improve reshaping of bool_masked_pos

* Add support for DeiT, fix tests

* Improve docstrings

* Apply new black version

* Improve script

* Fix bug

* Improve README

* Apply suggestions from code review

* Remove DS_Store and add to gitignore

* Apply suggestions from code review + fix BEiT Flax

* Revert BEiT changes

* Improve README

* Fix code quality

* Improve README
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 426b9623
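
This commit introduces `ViTForMaskedImageModeling` (plus DeiT and Swin variants and an `AutoModelForMaskedImageModeling` mapping). As a quick orientation before the diff, a hedged end-to-end sketch of how the new head is driven; the checkpoint name matches the docstring example below, while the 60% mask ratio is illustrative (SimMIM's reported default), not something fixed by this commit:

```python
import torch
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForMaskedImageModeling

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

# Randomly mask 60% of the patches (1 = masked, 0 = visible)
num_patches = (model.config.image_size // model.config.patch_size) ** 2
bool_masked_pos = (torch.rand(1, num_patches) < 0.6).long()

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs, bool_masked_pos=bool_masked_pos)

print(outputs.loss)          # L1 reconstruction loss over masked pixels
print(outputs.logits.shape)  # reconstructed image: (1, 3, 224, 224)
```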
@@ -24,8 +24,13 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
-from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
+from ...file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_vit import ViTConfig
@@ -67,14 +72,15 @@ def to_2tuple(x):
 class ViTEmbeddings(nn.Module):
     """
-    Construct the CLS token, position and patch embeddings.
+    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
     """

-    def __init__(self, config):
+    def __init__(self, config, use_mask_token=False):
         super().__init__()
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = PatchEmbeddings(
             image_size=config.image_size,
             patch_size=config.patch_size,
@@ -117,10 +123,17 @@ class ViTEmbeddings(nn.Module):
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

-    def forward(self, pixel_values, interpolate_pos_encoding=False):
+    def forward(self, pixel_values, bool_masked_pos=None, interpolate_pos_encoding=False):
         batch_size, num_channels, height, width = pixel_values.shape
         embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

+        batch_size, seq_len, _ = embeddings.size()
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
         # add the [CLS] token to the embedded patch tokens
         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
         embeddings = torch.cat((cls_tokens, embeddings), dim=1)
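
The masking logic in this hunk is easy to verify in isolation. A minimal standalone sketch (toy shapes; the mask token is zero-initialized here just as the real `nn.Parameter` is) showing that masked positions end up holding exactly the mask token:

```python
import torch

# Toy dimensions, for illustration only
batch_size, seq_len, hidden_size = 2, 4, 8

embeddings = torch.randn(batch_size, seq_len, hidden_size)
mask_token = torch.zeros(1, 1, hidden_size)  # a learnable nn.Parameter in the real module

# 1 = masked patch, 0 = visible patch
bool_masked_pos = torch.tensor([[1, 0, 0, 1], [0, 1, 0, 0]])

mask_tokens = mask_token.expand(batch_size, seq_len, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)  # (batch, seq_len, 1)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

# Masked positions now hold the mask token exactly; visible ones are untouched
assert torch.equal(embeddings[0, 0], mask_token[0, 0])
```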
@@ -422,10 +435,6 @@ class ViTPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
@@ -476,11 +485,11 @@ VIT_INPUTS_DOCSTRING = r"""
     VIT_START_DOCSTRING,
 )
 class ViTModel(ViTPreTrainedModel):
-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
         super().__init__(config)
         self.config = config

-        self.embeddings = ViTEmbeddings(config)
+        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
         self.encoder = ViTEncoder(config)

         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -512,6 +521,7 @@ class ViTModel(ViTPreTrainedModel):
     def forward(
         self,
         pixel_values=None,
+        bool_masked_pos=None,
         head_mask=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -534,7 +544,9 @@ class ViTModel(ViTPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

-        embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )

         encoder_outputs = self.encoder(
             embedding_output,
...@@ -573,6 +585,107 @@ class ViTPooler(nn.Module): ...@@ -573,6 +585,107 @@ class ViTPooler(nn.Module):
return pooled_output return pooled_output
@add_start_docstrings(
"ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM <https://arxiv.org/abs/2111.09886>`__.",
VIT_START_DOCSTRING,
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
self.decoder = nn.Sequential(
nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
nn.PixelShuffle(config.encoder_stride),
)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
interpolate_pos_encoding=None,
return_dict=None,
):
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Returns:
Examples:
```python
>>> from transformers import ViTFeatureExtractor, ViTForMaskedImageModeling
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
sequence_output = outputs[0]
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output[:, 1:]
batch_size, sequence_length, num_channels = sequence_output.shape
height = width = int(sequence_length**0.5)
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
# Reconstruct pixel values
reconstructed_pixel_values = self.decoder(sequence_output)
masked_im_loss = None
if bool_masked_pos is not None:
size = self.config.image_size // self.config.patch_size
bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
mask = (
bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
.repeat_interleave(self.config.patch_size, 2)
.unsqueeze(1)
.contiguous()
)
reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
if not return_dict:
output = (reconstructed_pixel_values,) + outputs[2:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedLMOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings( @add_start_docstrings(
""" """
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
......
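
To see how the pieces of `ViTForMaskedImageModeling` fit together, a minimal self-contained sketch (toy hyperparameters, not the shipped defaults) of the decoder and the masked L1 loss: the 1x1 `Conv2d` plus `PixelShuffle` maps each hidden vector back to an `encoder_stride` x `encoder_stride` pixel block, and `repeat_interleave` upsamples the patch-level mask to pixel resolution so only masked pixels contribute to the loss:

```python
import torch
from torch import nn

# Toy hyperparameters (hypothetical; ViT-base uses hidden_size=768, patch_size=16)
hidden_size, image_size, patch_size, num_channels = 32, 16, 4, 3
encoder_stride = patch_size  # decoder must upsample back to full resolution
num_patches = (image_size // patch_size) ** 2

decoder = nn.Sequential(
    nn.Conv2d(hidden_size, encoder_stride**2 * 3, kernel_size=1),
    nn.PixelShuffle(encoder_stride),
)

batch_size = 2
pixel_values = torch.randn(batch_size, num_channels, image_size, image_size)
bool_masked_pos = torch.randint(0, 2, (batch_size, num_patches))

# Pretend encoder output (CLS token already dropped), reshaped to a 2D feature map
size = image_size // patch_size
sequence_output = torch.randn(batch_size, num_patches, hidden_size)
feature_map = sequence_output.permute(0, 2, 1).reshape(batch_size, hidden_size, size, size)

reconstructed = decoder(feature_map)
assert reconstructed.shape == pixel_values.shape  # (2, 3, 16, 16)

# Upsample the patch mask to pixel resolution: (B, size, size) -> (B, 1, H, W)
mask = (
    bool_masked_pos.reshape(-1, size, size)
    .repeat_interleave(patch_size, 1)
    .repeat_interleave(patch_size, 2)
    .unsqueeze(1)
    .float()
)
loss = nn.functional.l1_loss(pixel_values, reconstructed, reduction="none")
# Average over masked pixels only (1e-5 guards an empty mask), then per channel
masked_im_loss = (loss * mask).sum() / (mask.sum() + 1e-5) / num_channels
```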
@@ -362,6 +362,9 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None

 MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None

+MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
+
 MODEL_FOR_MASKED_LM_MAPPING = None

@@ -460,6 +463,13 @@ class AutoModelForImageSegmentation(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class AutoModelForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class AutoModelForMaskedLM(metaclass=DummyObject):
     _backends = ["torch"]

@@ -1305,6 +1315,13 @@ class DeiTForImageClassificationWithTeacher(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class DeiTForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class DeiTModel(metaclass=DummyObject):
     _backends = ["torch"]

@@ -3449,6 +3466,13 @@ class SwinForImageClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class SwinForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SwinModel(metaclass=DummyObject):
     _backends = ["torch"]

@@ -3782,6 +3806,13 @@ class ViTForImageClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class ViTForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class ViTModel(metaclass=DummyObject):
     _backends = ["torch"]
...
@@ -72,6 +72,7 @@ if is_torch_available():
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
         MODEL_FOR_CAUSAL_LM_MAPPING,
         MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -165,6 +166,11 @@ class ModelTesterMixin:
             inputs_dict["labels"] = torch.zeros(
                 (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
             )
+        elif model_class in get_values(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
+            num_patches = self.model_tester.image_size // self.model_tester.patch_size
+            inputs_dict["bool_masked_pos"] = torch.zeros(
+                (self.model_tester.batch_size, num_patches**2), dtype=torch.long, device=torch_device
+            )
         return inputs_dict

     def test_save_load(self):
...
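
The common tester above feeds an all-zeros `bool_masked_pos` (nothing masked), which exercises the forward pass but produces no loss signal. For actual pre-training one would pass a random mask; a minimal sketch (the 0.6 ratio and independent per-patch sampling are simplifications; SimMIM itself masks in larger blocks):

```python
import torch

# Hypothetical settings matching a ViT-base input
batch_size, image_size, patch_size = 4, 224, 16
num_patches = (image_size // patch_size) ** 2
mask_ratio = 0.6

# Simple per-patch Bernoulli mask: 1 = masked, 0 = visible
bool_masked_pos = (torch.rand(batch_size, num_patches) < mask_ratio).long()

# Passed to the model alongside pixel_values:
# outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
```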
@@ -35,6 +35,7 @@ if is_torch_available():
         MODEL_MAPPING,
         DeiTForImageClassification,
         DeiTForImageClassificationWithTeacher,
+        DeiTForMaskedImageModeling,
         DeiTModel,
     )
     from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -67,6 +68,7 @@ class DeiTModelTester:
         initializer_range=0.02,
         num_labels=3,
         scope=None,
+        encoder_stride=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -85,6 +87,7 @@ class DeiTModelTester:
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.encoder_stride = encoder_stride

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -111,6 +114,7 @@ class DeiTModelTester:
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            encoder_stride=self.encoder_stride,
         )

     def create_and_check_model(self, config, pixel_values, labels):
@@ -155,6 +159,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             DeiTModel,
             DeiTForImageClassification,
             DeiTForImageClassificationWithTeacher,
+            DeiTForMaskedImageModeling,
         )
         if is_torch_available()
         else ()
...
@@ -31,7 +31,7 @@ if is_torch_available():
     import torch
     from torch import nn

-    from transformers import SwinForImageClassification, SwinModel
+    from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
     from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple

 if is_vision_available():
@@ -74,6 +74,7 @@ class SwinModelTester:
         scope=None,
         use_labels=True,
         type_sequence_label_size=10,
+        encoder_stride=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -98,6 +99,7 @@ class SwinModelTester:
         self.scope = scope
         self.use_labels = use_labels
         self.type_sequence_label_size = type_sequence_label_size
+        self.encoder_stride = encoder_stride

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -129,6 +131,7 @@ class SwinModelTester:
             path_norm=self.patch_norm,
             layer_norm_eps=self.layer_norm_eps,
             initializer_range=self.initializer_range,
+            encoder_stride=self.encoder_stride,
         )

     def create_and_check_model(self, config, pixel_values, labels):
@@ -169,6 +172,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
         (
             SwinModel,
             SwinForImageClassification,
+            SwinForMaskedImageModeling,
         )
         if is_torch_available()
         else ()
...
...@@ -30,7 +30,7 @@ if is_torch_available(): ...@@ -30,7 +30,7 @@ if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from transformers import ViTForImageClassification, ViTModel from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
...@@ -61,6 +61,7 @@ class ViTModelTester: ...@@ -61,6 +61,7 @@ class ViTModelTester:
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
scope=None, scope=None,
encoder_stride=2,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -79,6 +80,7 @@ class ViTModelTester: ...@@ -79,6 +80,7 @@ class ViTModelTester:
self.type_sequence_label_size = type_sequence_label_size self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.scope = scope self.scope = scope
self.encoder_stride = encoder_stride
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
...@@ -105,6 +107,7 @@ class ViTModelTester: ...@@ -105,6 +107,7 @@ class ViTModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
) )
def create_and_check_model(self, config, pixel_values, labels): def create_and_check_model(self, config, pixel_values, labels):
...@@ -148,6 +151,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -148,6 +151,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
( (
ViTModel, ViTModel,
ViTForImageClassification, ViTForImageClassification,
ViTForMaskedImageModeling,
) )
if is_torch_available() if is_torch_available()
else () else ()
......