Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d4b3e56d
Unverified
Commit
d4b3e56d
authored
Jan 31, 2022
by
NielsRogge
Committed by
GitHub
Jan 31, 2022
Browse files
[Hotfix] Fix Swin model outputs (#15414)
* Fix Swin model outputs * Rename pooler
parent
38dfb40a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
21 deletions
+40
-21
src/transformers/models/swin/modeling_swin.py
src/transformers/models/swin/modeling_swin.py
+35
-18
tests/test_modeling_swin.py
tests/test_modeling_swin.py
+5
-3
No files found.
src/transformers/models/swin/modeling_swin.py
View file @
d4b3e56d
...
@@ -21,11 +21,11 @@ import math
...
@@ -21,11 +21,11 @@ import math
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...file_utils
import
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
replace_return_docstrings
from
...file_utils
import
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
replace_return_docstrings
from
...modeling_outputs
import
BaseModelOutput
,
SequenceClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
,
SequenceClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...modeling_utils
import
PreTrainedModel
,
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_swin
import
SwinConfig
from
.configuration_swin
import
SwinConfig
...
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
...
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
self
.
projection
=
nn
.
Conv2d
(
num_channels
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
self
.
projection
=
nn
.
Conv2d
(
num_channels
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
def
forward
(
self
,
pixel_values
):
def
forward
(
self
,
pixel_values
):
pixel_value
s
=
self
.
projection
(
pixel_values
).
flatten
(
2
).
transpose
(
1
,
2
)
embedding
s
=
self
.
projection
(
pixel_values
).
flatten
(
2
).
transpose
(
1
,
2
)
return
pixel_value
s
return
embedding
s
class
SwinPatchMerging
(
nn
.
Module
):
class
SwinPatchMerging
(
nn
.
Module
):
...
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
...
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
SWIN_START_DOCSTRING
,
SWIN_START_DOCSTRING
,
)
)
class
SwinModel
(
SwinPreTrainedModel
):
class
SwinModel
(
SwinPreTrainedModel
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
add_pooling_layer
=
True
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
config
=
config
self
.
config
=
config
self
.
num_layers
=
len
(
config
.
depths
)
self
.
num_layers
=
len
(
config
.
depths
)
...
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
...
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
self
.
encoder
=
SwinEncoder
(
config
,
self
.
embeddings
.
patch_grid
)
self
.
encoder
=
SwinEncoder
(
config
,
self
.
embeddings
.
patch_grid
)
self
.
layernorm
=
nn
.
LayerNorm
(
self
.
num_features
,
eps
=
config
.
layer_norm_eps
)
self
.
layernorm
=
nn
.
LayerNorm
(
self
.
num_features
,
eps
=
config
.
layer_norm_eps
)
self
.
pool
=
nn
.
AdaptiveAvgPool1d
(
1
)
self
.
pool
er
=
nn
.
AdaptiveAvgPool1d
(
1
)
if
add_pooling_layer
else
None
# Initialize weights and apply final processing
# Initialize weights and apply final processing
self
.
post_init
()
self
.
post_init
()
...
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
...
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutput
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutput
WithPooling
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
def
forward
(
self
,
self
,
pixel_values
=
None
,
pixel_values
=
None
,
...
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
...
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
sequence_output
=
encoder_outputs
[
0
]
sequence_output
=
encoder_outputs
[
0
]
sequence_output
=
self
.
layernorm
(
sequence_output
)
sequence_output
=
self
.
layernorm
(
sequence_output
)
sequence_output
=
self
.
pool
(
sequence_output
.
transpose
(
1
,
2
))
sequence_output
=
torch
.
flatten
(
sequence_output
,
1
)
pooled_output
=
None
if
self
.
pooler
is
not
None
:
pooled_output
=
self
.
pooler
(
sequence_output
.
transpose
(
1
,
2
))
pooled_output
=
torch
.
flatten
(
pooled_output
,
1
)
if
not
return_dict
:
if
not
return_dict
:
return
(
sequence_output
,)
+
encoder_outputs
[
1
:]
return
(
sequence_output
,
pooled_output
)
+
encoder_outputs
[
1
:]
return
BaseModelOutput
(
return
BaseModelOutput
WithPooling
(
last_hidden_state
=
sequence_output
,
last_hidden_state
=
sequence_output
,
pooler_output
=
pooled_output
,
hidden_states
=
encoder_outputs
.
hidden_states
,
hidden_states
=
encoder_outputs
.
hidden_states
,
attentions
=
encoder_outputs
.
attentions
,
attentions
=
encoder_outputs
.
attentions
,
)
)
...
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
...
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
return_dict
=
return_dict
,
return_dict
=
return_dict
,
)
)
sequence
_output
=
outputs
[
0
]
pooled
_output
=
outputs
[
1
]
logits
=
self
.
classifier
(
sequence
_output
)
logits
=
self
.
classifier
(
pooled
_output
)
loss
=
None
loss
=
None
if
labels
is
not
None
:
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
if
self
.
config
.
problem_type
is
None
:
# We are doing regression
if
self
.
num_labels
==
1
:
self
.
config
.
problem_type
=
"regression"
elif
self
.
num_labels
>
1
and
(
labels
.
dtype
==
torch
.
long
or
labels
.
dtype
==
torch
.
int
):
self
.
config
.
problem_type
=
"single_label_classification"
else
:
self
.
config
.
problem_type
=
"multi_label_classification"
if
self
.
config
.
problem_type
==
"regression"
:
loss_fct
=
MSELoss
()
loss_fct
=
MSELoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
),
labels
.
view
(
-
1
))
if
self
.
num_labels
==
1
:
else
:
loss
=
loss_fct
(
logits
.
squeeze
(),
labels
.
squeeze
())
else
:
loss
=
loss_fct
(
logits
,
labels
)
elif
self
.
config
.
problem_type
==
"single_label_classification"
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
elif
self
.
config
.
problem_type
==
"multi_label_classification"
:
loss_fct
=
BCEWithLogitsLoss
()
loss
=
loss_fct
(
logits
,
labels
)
if
not
return_dict
:
if
not
return_dict
:
output
=
(
logits
,)
+
outputs
[
1
:]
output
=
(
logits
,)
+
outputs
[
2
:]
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
return
SequenceClassifierOutput
(
return
SequenceClassifierOutput
(
...
...
tests/test_modeling_swin.py
View file @
d4b3e56d
...
@@ -137,9 +137,11 @@ class SwinModelTester:
...
@@ -137,9 +137,11 @@ class SwinModelTester:
model
.
eval
()
model
.
eval
()
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
num_features
=
int
(
config
.
embed_dim
*
2
**
(
len
(
config
.
depths
)
-
1
))
# since the model we're testing only consists of a single layer, expected_seq_len = number of patches
expected_seq_len
=
(
config
.
image_size
//
config
.
patch_size
)
**
2
expected_dim
=
int
(
config
.
embed_dim
*
2
**
(
len
(
config
.
depths
)
-
1
))
self
.
parent
.
assertEqual
(
result
.
last_hidden_state
.
shape
,
(
self
.
batch_size
,
num_features
))
self
.
parent
.
assertEqual
(
result
.
last_hidden_state
.
shape
,
(
self
.
batch_size
,
expected_seq_len
,
expected_dim
))
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
config
.
num_labels
=
self
.
type_sequence_label_size
config
.
num_labels
=
self
.
type_sequence_label_size
...
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
...
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
expected_shape
=
torch
.
Size
((
1
,
1000
))
expected_shape
=
torch
.
Size
((
1
,
1000
))
self
.
assertEqual
(
outputs
.
logits
.
shape
,
expected_shape
)
self
.
assertEqual
(
outputs
.
logits
.
shape
,
expected_shape
)
expected_slice
=
torch
.
tensor
([
-
0.
2952
,
-
0.
4777
,
0.
2025
]).
to
(
torch_device
)
expected_slice
=
torch
.
tensor
([
-
0.
0948
,
-
0.
6454
,
-
0.
0921
]).
to
(
torch_device
)
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[
0
,
:
3
],
expected_slice
,
atol
=
1e-4
))
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[
0
,
:
3
],
expected_slice
,
atol
=
1e-4
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment