chenpangpang / transformers

Commit 1dfc11e9 (unverified)
authored Mar 23, 2022 by João Gustavo A. Amorim, committed by GitHub on Mar 23, 2022
parent bb3a1d34

complete the type annotations for config parameters (#16263)

Showing 4 changed files with 33 additions and 33 deletions (+33, −33)
src/transformers/models/deit/modeling_deit.py        +8  −8
src/transformers/models/vilt/modeling_vilt.py        +3  −3
src/transformers/models/vit/modeling_vit.py          +11 −11
src/transformers/models/vit_mae/modeling_vit_mae.py  +11 −11
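The pattern is the same in all four files: `def __init__(self, config) -> None:` becomes `def __init__(self, config: <Model>Config) -> None:`. A minimal sketch of what the annotation buys; the class name below is made up for illustration, and only `ViTConfig` and the attribute names come from the diff:

from torch import nn
from transformers import ViTConfig


class ExampleSelfOutput(nn.Module):
    # Hypothetical stand-in for the modules touched by this commit.
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        # With the annotation, editors and static type checkers resolve these
        # attributes on ViTConfig instead of treating `config` as Any.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)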
src/transformers/models/deit/modeling_deit.py

@@ -147,7 +147,7 @@ class PatchEmbeddings(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
 class DeiTSelfAttention(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -213,7 +213,7 @@ class DeiTSelfOutput(nn.Module):
     layernorm applied before each block.
     """

-    def __init__(self, config) -> None:
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -228,7 +228,7 @@ class DeiTSelfOutput(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
 class DeiTAttention(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__()
         self.attention = DeiTSelfAttention(config)
         self.output = DeiTSelfOutput(config)
@@ -268,7 +268,7 @@ class DeiTAttention(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
 class DeiTIntermediate(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
@@ -286,7 +286,7 @@ class DeiTIntermediate(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
 class DeiTOutput(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -304,7 +304,7 @@ class DeiTOutput(nn.Module):
 class DeiTLayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""

-    def __init__(self, config) -> None:
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
@@ -345,7 +345,7 @@ class DeiTLayer(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
 class DeiTEncoder(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
@@ -553,7 +553,7 @@ class DeiTModel(DeiTPreTrainedModel):
 # Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT
 class DeiTPooler(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: DeiTConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
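A usage sketch for the DeiT side of the change, assuming the internal module path `transformers.models.deit.modeling_deit` (these classes are implementation details, not the stable public API):

import torch
from transformers import DeiTConfig
from transformers.models.deit.modeling_deit import DeiTAttention

# Small config so the example runs quickly; 192 is divisible by 3 heads.
config = DeiTConfig(hidden_size=192, num_attention_heads=3)
attention = DeiTAttention(config)  # the parameter is now annotated as `config: DeiTConfig`

hidden_states = torch.randn(1, 198, config.hidden_size)
(context,) = attention(hidden_states)  # returns a 1-tuple when output_attentions=False
print(context.shape)  # torch.Size([1, 198, 192])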
src/transformers/models/vilt/modeling_vilt.py

@@ -388,7 +388,7 @@ class ViltSelfOutput(nn.Module):
     layernorm applied before each block.
     """

-    def __init__(self, config) -> None:
+    def __init__(self, config: ViltConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -437,7 +437,7 @@ class ViltAttention(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
 class ViltIntermediate(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViltConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
@@ -455,7 +455,7 @@ class ViltIntermediate(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
 class ViltOutput(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViltConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
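For ViLT the same three copied-from-ViT modules are annotated with `ViltConfig`. A quick check, assuming the public `ViltConfig` class, that the attributes these `__init__` methods read are defined on it:

from transformers import ViltConfig

config = ViltConfig()
# Attributes used by ViltSelfOutput / ViltIntermediate / ViltOutput above.
print(config.hidden_size, config.intermediate_size)
print(config.hidden_dropout_prob, config.hidden_act)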
src/transformers/models/vit/modeling_vit.py

@@ -77,7 +77,7 @@ class ViTEmbeddings(nn.Module):
     """

-    def __init__(self, config, use_mask_token: bool = False) -> None:
+    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
@@ -192,7 +192,7 @@ class PatchEmbeddings(nn.Module):
 class ViTSelfAttention(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -257,7 +257,7 @@ class ViTSelfOutput(nn.Module):
     layernorm applied before each block.
     """

-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -271,7 +271,7 @@ class ViTSelfOutput(nn.Module):
 class ViTAttention(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         self.attention = ViTSelfAttention(config)
         self.output = ViTSelfOutput(config)
@@ -310,7 +310,7 @@ class ViTAttention(nn.Module):
 class ViTIntermediate(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
@@ -327,7 +327,7 @@ class ViTIntermediate(nn.Module):
 class ViTOutput(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -344,7 +344,7 @@ class ViTOutput(nn.Module):
 class ViTLayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""

-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
@@ -384,7 +384,7 @@ class ViTLayer(nn.Module):
 class ViTEncoder(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
@@ -595,7 +595,7 @@ class ViTModel(ViTPreTrainedModel):
 class ViTPooler(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: ViTConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
@@ -614,7 +614,7 @@ class ViTPooler(nn.Module):
     VIT_START_DOCSTRING,
 )
 class ViTForMaskedImageModeling(ViTPreTrainedModel):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__(config)
         self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
@@ -724,7 +724,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
     VIT_START_DOCSTRING,
 )
 class ViTForImageClassification(ViTPreTrainedModel):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTConfig) -> None:
         super().__init__(config)
         self.num_labels = config.num_labels
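For ViT the annotation also reaches the public heads. A minimal end-to-end sketch using the public API with a small, randomly initialized config (not part of the diff):

import torch
from transformers import ViTConfig, ViTForImageClassification

config = ViTConfig(hidden_size=192, num_hidden_layers=4, num_attention_heads=3, intermediate_size=768)
model = ViTForImageClassification(config)  # __init__ now reads `config: ViTConfig`

pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
logits = model(pixel_values).logits
print(logits.shape)  # (1, config.num_labels)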
src/transformers/models/vit_mae/modeling_vit_mae.py

@@ -134,7 +134,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None

-# copied from transformers.models.vit.modeling_vit.to_2tuple
+# copied from transformers.models.vit.modeling_vit.to_2tuple ViT->ViTMAE
 def to_2tuple(x):
     if isinstance(x, collections.abc.Iterable):
         return x
@@ -316,9 +316,9 @@ class PatchEmbeddings(nn.Module):
         return x

-# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
 class ViTMAESelfAttention(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTMAEConfig) -> None:
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -384,7 +384,7 @@ class ViTMAESelfOutput(nn.Module):
     layernorm applied before each block.
     """

-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTMAEConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -399,7 +399,7 @@ class ViTMAESelfOutput(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
 class ViTMAEAttention(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTMAEConfig) -> None:
         super().__init__()
         self.attention = ViTMAESelfAttention(config)
         self.output = ViTMAESelfOutput(config)
@@ -437,9 +437,9 @@ class ViTMAEAttention(nn.Module):
         return outputs

-# Copied from transformers.models.vit.modeling_vit.ViTIntermediate
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
 class ViTMAEIntermediate(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTMAEConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
@@ -455,9 +455,9 @@ class ViTMAEIntermediate(nn.Module):
         return hidden_states

-# Copied from transformers.models.vit.modeling_vit.ViTOutput
+# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE
 class ViTMAEOutput(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTMAEConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -475,7 +475,7 @@ class ViTMAEOutput(nn.Module):
 class ViTMAELayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""

-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTMAEConfig) -> None:
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
@@ -516,7 +516,7 @@ class ViTMAELayer(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
 class ViTMAEEncoder(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: ViTMAEConfig) -> None:
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
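One practical effect of the ViTMAE annotations, sketched with the internal module path (not a stable public API): a static checker such as mypy can now flag a call that passes a config of the wrong class, which the previous untyped signature accepted silently.

from transformers import ViTMAEConfig
from transformers.models.vit_mae.modeling_vit_mae import ViTMAELayer

config = ViTMAEConfig(hidden_size=256, num_attention_heads=8, intermediate_size=512)
layer = ViTMAELayer(config)  # OK: matches the new `config: ViTMAEConfig` annotation

# layer = ViTMAELayer(BertConfig())  # illustrative: a type checker would now report this,
#                                    # even though Python itself would not complain
print(layer)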