Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bd43151a
Unverified
Commit
bd43151a
authored
Jun 14, 2022
by
amyeroberts
Committed by
GitHub
Jun 14, 2022
Browse files
Swin main layer (#17693)
* Swin models call TFSwinMainLayer * Tidy up
parent
3960ce91
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
18 deletions
+68
-18
src/transformers/models/swin/modeling_tf_swin.py
src/transformers/models/swin/modeling_tf_swin.py
+68
-18
No files found.
src/transformers/models/swin/modeling_tf_swin.py
View file @
bd43151a
...
...
@@ -24,7 +24,13 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import
tensorflow
as
tf
from
...activations_tf
import
ACT2FN
from
...modeling_tf_utils
import
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
get_initializer
,
unpack_inputs
from
...modeling_tf_utils
import
(
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
get_initializer
,
keras_serializable
,
unpack_inputs
,
)
from
...tf_utils
import
shape_list
from
...utils
import
(
ModelOutput
,
...
...
@@ -1069,15 +1075,14 @@ class AdaptiveAveragePooling1D(tf.keras.layers.Layer):
return
{
**
base_config
,
**
config
}
@
add_start_docstrings
(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top."
,
SWIN_START_DOCSTRING
,
)
class
TFSwinModel
(
TFSwinPreTrainedModel
):
@
keras_serializable
class
TFSwinMainLayer
(
tf
.
keras
.
layers
.
Layer
):
config_class
=
SwinConfig
def
__init__
(
self
,
config
:
SwinConfig
,
add_pooling_layer
:
bool
=
True
,
use_mask_token
:
bool
=
False
,
**
kwargs
)
->
None
:
super
().
__init__
(
config
,
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
config
=
config
self
.
num_layers
=
len
(
config
.
depths
)
self
.
num_features
=
int
(
config
.
embed_dim
*
2
**
(
self
.
num_layers
-
1
))
...
...
@@ -1104,15 +1109,6 @@ class TFSwinModel(TFSwinPreTrainedModel):
raise
NotImplementedError
return
[
None
]
*
len
(
self
.
config
.
depths
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
processor_class
=
_FEAT_EXTRACTOR_FOR_DOC
,
checkpoint
=
_CHECKPOINT_FOR_DOC
,
output_type
=
TFSwinModelOutput
,
config_class
=
_CONFIG_FOR_DOC
,
modality
=
"vision"
,
expected_output
=
_EXPECTED_OUTPUT_SHAPE
,
)
@
unpack_inputs
def
call
(
self
,
...
...
@@ -1175,6 +1171,60 @@ class TFSwinModel(TFSwinPreTrainedModel):
)
@
add_start_docstrings
(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top."
,
SWIN_START_DOCSTRING
,
)
class
TFSwinModel
(
TFSwinPreTrainedModel
):
def
__init__
(
self
,
config
:
SwinConfig
,
add_pooling_layer
:
bool
=
True
,
use_mask_token
:
bool
=
False
,
**
kwargs
)
->
None
:
super
().
__init__
(
config
,
**
kwargs
)
self
.
config
=
config
self
.
swin
=
TFSwinMainLayer
(
config
,
name
=
"swin"
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
processor_class
=
_FEAT_EXTRACTOR_FOR_DOC
,
checkpoint
=
_CHECKPOINT_FOR_DOC
,
output_type
=
TFSwinModelOutput
,
config_class
=
_CONFIG_FOR_DOC
,
modality
=
"vision"
,
expected_output
=
_EXPECTED_OUTPUT_SHAPE
,
)
@
unpack_inputs
def
call
(
self
,
pixel_values
:
Optional
[
tf
.
Tensor
]
=
None
,
bool_masked_pos
:
Optional
[
tf
.
Tensor
]
=
None
,
head_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
training
:
bool
=
False
,
)
->
Union
[
TFSwinModelOutput
,
Tuple
[
tf
.
Tensor
,
...]]:
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
pixel_values
is
None
:
raise
ValueError
(
"You have to specify pixel_values"
)
swin_outputs
=
self
.
swin
(
pixel_values
=
pixel_values
,
bool_masked_pos
=
bool_masked_pos
,
head_mask
=
head_mask
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
)
return
swin_outputs
class
PixelShuffle
(
tf
.
keras
.
layers
.
Layer
):
"""TF layer implementation of torch.nn.PixelShuffle"""
...
...
@@ -1238,7 +1288,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
def
__init__
(
self
,
config
:
SwinConfig
):
super
().
__init__
(
config
)
self
.
swin
=
TFSwinM
odel
(
config
,
add_pooling_layer
=
False
,
use_mask_token
=
True
,
name
=
"swin"
)
self
.
swin
=
TFSwinM
ainLayer
(
config
,
add_pooling_layer
=
False
,
use_mask_token
=
True
,
name
=
"swin"
)
self
.
decoder
=
TFSwinDecoder
(
config
,
name
=
"decoder"
)
...
...
@@ -1350,7 +1400,7 @@ class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificati
super
().
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
swin
=
TFSwinM
odel
(
config
,
name
=
"swin"
)
self
.
swin
=
TFSwinM
ainLayer
(
config
,
name
=
"swin"
)
# Classifier head
self
.
classifier
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment