Unverified Commit 9fd584e5 authored by NielsRogge, committed by GitHub

Add copied from statements and fix prefix (#16119)


Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent f284aa32
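For context on the comments this commit adds: a `# Copied from ...` annotation is the Transformers repository's mechanism for keeping duplicated modules in sync with their source. The repo's consistency tooling (`utils/check_copies.py`, also run via `make fix-copies`) compares the annotated definition against the referenced one and applies the optional `with Old->New` renaming. A minimal sketch of the convention, condensed from the self-output class that appears later in this diff (illustrative, not the verbatim file contents):

```python
import torch.nn as nn


# The annotation below declares that this class must stay identical to
# transformers.models.vit.modeling_vit.ViTSelfOutput, with every occurrence of
# "ViT" rewritten to "ViTMAE"; the check_copies tooling enforces this in CI.
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
```

If the ViT source later changes, regenerating the copies rather than editing this file by hand keeps the two definitions from drifting apart.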
@@ -134,9 +134,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Inspired by
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
# copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
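The first hunk is cut off before the helper's final branch; for reference, the complete two-tuple helper (a sketch mirroring the timm/ViT version the comment links to) is:

```python
import collections.abc


def to_2tuple(x):
    # An already-iterable size, e.g. (224, 224), is passed through unchanged.
    if isinstance(x, collections.abc.Iterable):
        return x
    # A scalar is duplicated, e.g. 224 -> (224, 224).
    return (x, x)
```

`PatchEmbeddings` uses this to normalize `image_size` and `patch_size`, so both a bare integer and an explicit `(height, width)` pair are accepted.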
@@ -318,6 +316,7 @@ class PatchEmbeddings(nn.Module):
return x
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention
class ViTMAESelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
@@ -376,6 +375,7 @@ class ViTMAESelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
"""
The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
@@ -395,6 +395,7 @@ class ViTMAESelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module):
def __init__(self, config):
super().__init__()
@@ -429,6 +430,7 @@ class ViTMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate
class ViTMAEIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
@@ -446,6 +448,7 @@ class ViTMAEIntermediate(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTOutput
class ViTMAEOutput(nn.Module):
def __init__(self, config):
super().__init__()
@@ -461,6 +464,7 @@ class ViTMAEOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE
class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -488,7 +492,6 @@ class ViTMAELayer(nn.Module):
# in ViTMAE, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
# second residual connection is done here
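The comments in this hunk describe the pre-norm layout of the block: layernorm before attention, a first residual connection, layernorm again before the MLP, then a second residual connection. A self-contained sketch of that structure (a hypothetical `PreNormBlock`, not the actual `ViTMAELayer`; head-mask and output-attentions plumbing omitted):

```python
import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Minimal pre-norm transformer block illustrating the two residual connections."""

    def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int):
        super().__init__()
        self.layernorm_before = nn.LayerNorm(hidden_size)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.layernorm_after = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # First residual connection: attention runs on a normalized copy of the input.
        normed = self.layernorm_before(hidden_states)
        attention_output, _ = self.attention(normed, normed, normed)
        hidden_states = hidden_states + attention_output

        # Layernorm is applied again before the MLP, and the second residual
        # connection adds the MLP output back onto the post-attention states.
        layer_output = self.mlp(self.layernorm_after(hidden_states))
        return hidden_states + layer_output
```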
@@ -498,12 +501,8 @@ class ViTMAELayer(nn.Module):
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output)
return layer_output
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module):
def __init__(self, config):
super().__init__()
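The `feed_forward_chunk` method shown in the previous hunk isolates the MLP so it can be applied to the sequence in slices when `chunk_size_feed_forward` is set, bounding peak activation memory. A hedged sketch of the driving logic (a hand-rolled stand-in for the library's generic chunking helper; `chunk_size == 0` follows the convention of meaning "no chunking"):

```python
import torch


def chunked_feed_forward(feed_forward_chunk, attention_output: torch.Tensor,
                         chunk_size: int, seq_dim: int = 1):
    # chunk_size == 0 disables chunking: run the whole sequence at once.
    if chunk_size == 0:
        return feed_forward_chunk(attention_output)
    # The MLP is position-wise, so slices along the sequence dimension can be
    # processed independently and concatenated back together.
    chunks = attention_output.split(chunk_size, dim=seq_dim)
    return torch.cat([feed_forward_chunk(chunk) for chunk in chunks], dim=seq_dim)
```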
@@ -568,10 +567,11 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
"""
config_class = ViTMAEConfig
base_model_prefix = "vit_mae"
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
......
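The "fix prefix" half of the commit is the `base_model_prefix` change above. The head models expose the backbone under an attribute with that name (`self.vit`), so checkpoint keys begin with `vit.` and the weight-loading logic in `from_pretrained` relies on the prefix matching. A quick hedged check, assuming the public `facebook/vit-mae-base` checkpoint (the commented output is what one would expect, not a captured run):

```python
from transformers import ViTMAEForPreTraining

model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

# The backbone sits under the attribute named by base_model_prefix ...
print(type(model.vit).__name__)                            # ViTMAEModel
# ... so its parameter keys carry the "vit." prefix in the state dict.
print(next(iter(model.state_dict())).startswith("vit."))   # True
```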