Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bd43151a
Unverified
Commit
bd43151a
authored
Jun 14, 2022
by
amyeroberts
Committed by
GitHub
Jun 14, 2022
Browse files
Swin main layer (#17693)
* Swin models call TFSwinMainLayer * Tidy up
parent
3960ce91
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
18 deletions
+68
-18
src/transformers/models/swin/modeling_tf_swin.py
src/transformers/models/swin/modeling_tf_swin.py
+68
-18
No files found.
src/transformers/models/swin/modeling_tf_swin.py
View file @
bd43151a
...
...
@@ -24,7 +24,13 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import
tensorflow
as
tf
from
...activations_tf
import
ACT2FN
from
...modeling_tf_utils
import
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
get_initializer
,
unpack_inputs
from
...modeling_tf_utils
import
(
TFPreTrainedModel
,
TFSequenceClassificationLoss
,
get_initializer
,
keras_serializable
,
unpack_inputs
,
)
from
...tf_utils
import
shape_list
from
...utils
import
(
ModelOutput
,
...
...
@@ -1069,15 +1075,14 @@ class AdaptiveAveragePooling1D(tf.keras.layers.Layer):
return
{
**
base_config
,
**
config
}
@
add_start_docstrings
(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top."
,
SWIN_START_DOCSTRING
,
)
class
TFSwinModel
(
TFSwinPreTrainedModel
):
@
keras_serializable
class
TFSwinMainLayer
(
tf
.
keras
.
layers
.
Layer
):
config_class
=
SwinConfig
def
__init__
(
self
,
config
:
SwinConfig
,
add_pooling_layer
:
bool
=
True
,
use_mask_token
:
bool
=
False
,
**
kwargs
)
->
None
:
super
().
__init__
(
config
,
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
config
=
config
self
.
num_layers
=
len
(
config
.
depths
)
self
.
num_features
=
int
(
config
.
embed_dim
*
2
**
(
self
.
num_layers
-
1
))
...
...
@@ -1104,15 +1109,6 @@ class TFSwinModel(TFSwinPreTrainedModel):
raise
NotImplementedError
return
[
None
]
*
len
(
self
.
config
.
depths
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
processor_class
=
_FEAT_EXTRACTOR_FOR_DOC
,
checkpoint
=
_CHECKPOINT_FOR_DOC
,
output_type
=
TFSwinModelOutput
,
config_class
=
_CONFIG_FOR_DOC
,
modality
=
"vision"
,
expected_output
=
_EXPECTED_OUTPUT_SHAPE
,
)
@
unpack_inputs
def
call
(
self
,
...
...
@@ -1175,6 +1171,60 @@ class TFSwinModel(TFSwinPreTrainedModel):
)
@
add_start_docstrings
(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top."
,
SWIN_START_DOCSTRING
,
)
class
TFSwinModel
(
TFSwinPreTrainedModel
):
def
__init__
(
self
,
config
:
SwinConfig
,
add_pooling_layer
:
bool
=
True
,
use_mask_token
:
bool
=
False
,
**
kwargs
)
->
None
:
super
().
__init__
(
config
,
**
kwargs
)
self
.
config
=
config
self
.
swin
=
TFSwinMainLayer
(
config
,
name
=
"swin"
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
processor_class
=
_FEAT_EXTRACTOR_FOR_DOC
,
checkpoint
=
_CHECKPOINT_FOR_DOC
,
output_type
=
TFSwinModelOutput
,
config_class
=
_CONFIG_FOR_DOC
,
modality
=
"vision"
,
expected_output
=
_EXPECTED_OUTPUT_SHAPE
,
)
@
unpack_inputs
def
call
(
self
,
pixel_values
:
Optional
[
tf
.
Tensor
]
=
None
,
bool_masked_pos
:
Optional
[
tf
.
Tensor
]
=
None
,
head_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
training
:
bool
=
False
,
)
->
Union
[
TFSwinModelOutput
,
Tuple
[
tf
.
Tensor
,
...]]:
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
pixel_values
is
None
:
raise
ValueError
(
"You have to specify pixel_values"
)
swin_outputs
=
self
.
swin
(
pixel_values
=
pixel_values
,
bool_masked_pos
=
bool_masked_pos
,
head_mask
=
head_mask
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
)
return
swin_outputs
class
PixelShuffle
(
tf
.
keras
.
layers
.
Layer
):
"""TF layer implementation of torch.nn.PixelShuffle"""
...
...
@@ -1238,7 +1288,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
def
__init__
(
self
,
config
:
SwinConfig
):
super
().
__init__
(
config
)
self
.
swin
=
TFSwinM
odel
(
config
,
add_pooling_layer
=
False
,
use_mask_token
=
True
,
name
=
"swin"
)
self
.
swin
=
TFSwinM
ainLayer
(
config
,
add_pooling_layer
=
False
,
use_mask_token
=
True
,
name
=
"swin"
)
self
.
decoder
=
TFSwinDecoder
(
config
,
name
=
"decoder"
)
...
...
@@ -1350,7 +1400,7 @@ class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificati
super
().
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
swin
=
TFSwinM
odel
(
config
,
name
=
"swin"
)
self
.
swin
=
TFSwinM
ainLayer
(
config
,
name
=
"swin"
)
# Classifier head
self
.
classifier
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment