Unverified Commit 9fd584e5 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Add copied from statements and fix prefix (#16119)


Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent f284aa32
...@@ -134,9 +134,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput): ...@@ -134,9 +134,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
# Inspired by # Copied from transformers.models.vit.modeling_vit.to_2tuple
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
def to_2tuple(x): def to_2tuple(x):
if isinstance(x, collections.abc.Iterable): if isinstance(x, collections.abc.Iterable):
return x return x
...@@ -318,6 +316,7 @@ class PatchEmbeddings(nn.Module): ...@@ -318,6 +316,7 @@ class PatchEmbeddings(nn.Module):
return x return x
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention
class ViTMAESelfAttention(nn.Module): class ViTMAESelfAttention(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -376,6 +375,7 @@ class ViTMAESelfAttention(nn.Module): ...@@ -376,6 +375,7 @@ class ViTMAESelfAttention(nn.Module):
return outputs return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module): class ViTMAESelfOutput(nn.Module):
""" """
The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
...@@ -395,6 +395,7 @@ class ViTMAESelfOutput(nn.Module): ...@@ -395,6 +395,7 @@ class ViTMAESelfOutput(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module): class ViTMAEAttention(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -429,6 +430,7 @@ class ViTMAEAttention(nn.Module): ...@@ -429,6 +430,7 @@ class ViTMAEAttention(nn.Module):
return outputs return outputs
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate
class ViTMAEIntermediate(nn.Module): class ViTMAEIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -446,6 +448,7 @@ class ViTMAEIntermediate(nn.Module): ...@@ -446,6 +448,7 @@ class ViTMAEIntermediate(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTOutput
class ViTMAEOutput(nn.Module): class ViTMAEOutput(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -461,6 +464,7 @@ class ViTMAEOutput(nn.Module): ...@@ -461,6 +464,7 @@ class ViTMAEOutput(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE
class ViTMAELayer(nn.Module): class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation.""" """This corresponds to the Block class in the timm implementation."""
...@@ -488,7 +492,6 @@ class ViTMAELayer(nn.Module): ...@@ -488,7 +492,6 @@ class ViTMAELayer(nn.Module):
# in ViTMAE, layernorm is also applied after self-attention # in ViTMAE, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states) layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output) layer_output = self.intermediate(layer_output)
# second residual connection is done here # second residual connection is done here
...@@ -498,12 +501,8 @@ class ViTMAELayer(nn.Module): ...@@ -498,12 +501,8 @@ class ViTMAELayer(nn.Module):
return outputs return outputs
def feed_forward_chunk(self, attention_output):
    """Run the MLP sub-block on one chunk of the attention output.

    Applies the intermediate (up-projection + activation) module followed by
    the output (down-projection) module and returns the result.
    """
    return self.output(self.intermediate(attention_output))
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module): class ViTMAEEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -568,10 +567,11 @@ class ViTMAEPreTrainedModel(PreTrainedModel): ...@@ -568,10 +567,11 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
""" """
config_class = ViTMAEConfig config_class = ViTMAEConfig
base_model_prefix = "vit_mae" base_model_prefix = "vit"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)): if isinstance(module, (nn.Linear, nn.Conv2d)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment