Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d4b3e56d
Unverified
Commit
d4b3e56d
authored
Jan 31, 2022
by
NielsRogge
Committed by
GitHub
Jan 31, 2022
Browse files
[Hotfix] Fix Swin model outputs (#15414)
* Fix Swin model outputs * Rename pooler
parent
38dfb40a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
21 deletions
+40
-21
src/transformers/models/swin/modeling_swin.py
src/transformers/models/swin/modeling_swin.py
+35
-18
tests/test_modeling_swin.py
tests/test_modeling_swin.py
+5
-3
No files found.
src/transformers/models/swin/modeling_swin.py
View file @
d4b3e56d
...
@@ -21,11 +21,11 @@ import math
...
@@ -21,11 +21,11 @@ import math
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...file_utils
import
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
replace_return_docstrings
from
...file_utils
import
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
replace_return_docstrings
from
...modeling_outputs
import
BaseModelOutput
,
SequenceClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
,
SequenceClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...modeling_utils
import
PreTrainedModel
,
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_swin
import
SwinConfig
from
.configuration_swin
import
SwinConfig
...
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
...
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
self
.
projection
=
nn
.
Conv2d
(
num_channels
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
self
.
projection
=
nn
.
Conv2d
(
num_channels
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
def
forward
(
self
,
pixel_values
):
def
forward
(
self
,
pixel_values
):
pixel_value
s
=
self
.
projection
(
pixel_values
).
flatten
(
2
).
transpose
(
1
,
2
)
embedding
s
=
self
.
projection
(
pixel_values
).
flatten
(
2
).
transpose
(
1
,
2
)
return
pixel_value
s
return
embedding
s
class
SwinPatchMerging
(
nn
.
Module
):
class
SwinPatchMerging
(
nn
.
Module
):
...
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
...
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
SWIN_START_DOCSTRING
,
SWIN_START_DOCSTRING
,
)
)
class
SwinModel
(
SwinPreTrainedModel
):
class
SwinModel
(
SwinPreTrainedModel
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
add_pooling_layer
=
True
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
config
=
config
self
.
config
=
config
self
.
num_layers
=
len
(
config
.
depths
)
self
.
num_layers
=
len
(
config
.
depths
)
...
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
...
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
self
.
encoder
=
SwinEncoder
(
config
,
self
.
embeddings
.
patch_grid
)
self
.
encoder
=
SwinEncoder
(
config
,
self
.
embeddings
.
patch_grid
)
self
.
layernorm
=
nn
.
LayerNorm
(
self
.
num_features
,
eps
=
config
.
layer_norm_eps
)
self
.
layernorm
=
nn
.
LayerNorm
(
self
.
num_features
,
eps
=
config
.
layer_norm_eps
)
self
.
pool
=
nn
.
AdaptiveAvgPool1d
(
1
)
self
.
pool
er
=
nn
.
AdaptiveAvgPool1d
(
1
)
if
add_pooling_layer
else
None
# Initialize weights and apply final processing
# Initialize weights and apply final processing
self
.
post_init
()
self
.
post_init
()
...
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
...
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
SWIN_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutput
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutput
WithPooling
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
def
forward
(
self
,
self
,
pixel_values
=
None
,
pixel_values
=
None
,
...
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
...
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
sequence_output
=
encoder_outputs
[
0
]
sequence_output
=
encoder_outputs
[
0
]
sequence_output
=
self
.
layernorm
(
sequence_output
)
sequence_output
=
self
.
layernorm
(
sequence_output
)
sequence_output
=
self
.
pool
(
sequence_output
.
transpose
(
1
,
2
))
sequence_output
=
torch
.
flatten
(
sequence_output
,
1
)
pooled_output
=
None
if
self
.
pooler
is
not
None
:
pooled_output
=
self
.
pooler
(
sequence_output
.
transpose
(
1
,
2
))
pooled_output
=
torch
.
flatten
(
pooled_output
,
1
)
if
not
return_dict
:
if
not
return_dict
:
return
(
sequence_output
,)
+
encoder_outputs
[
1
:]
return
(
sequence_output
,
pooled_output
)
+
encoder_outputs
[
1
:]
return
BaseModelOutput
(
return
BaseModelOutput
WithPooling
(
last_hidden_state
=
sequence_output
,
last_hidden_state
=
sequence_output
,
pooler_output
=
pooled_output
,
hidden_states
=
encoder_outputs
.
hidden_states
,
hidden_states
=
encoder_outputs
.
hidden_states
,
attentions
=
encoder_outputs
.
attentions
,
attentions
=
encoder_outputs
.
attentions
,
)
)
...
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
...
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
return_dict
=
return_dict
,
return_dict
=
return_dict
,
)
)
sequence
_output
=
outputs
[
0
]
pooled
_output
=
outputs
[
1
]
logits
=
self
.
classifier
(
sequence
_output
)
logits
=
self
.
classifier
(
pooled
_output
)
loss
=
None
loss
=
None
if
labels
is
not
None
:
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
if
self
.
config
.
problem_type
is
None
:
# We are doing regression
if
self
.
num_labels
==
1
:
self
.
config
.
problem_type
=
"regression"
elif
self
.
num_labels
>
1
and
(
labels
.
dtype
==
torch
.
long
or
labels
.
dtype
==
torch
.
int
):
self
.
config
.
problem_type
=
"single_label_classification"
else
:
self
.
config
.
problem_type
=
"multi_label_classification"
if
self
.
config
.
problem_type
==
"regression"
:
loss_fct
=
MSELoss
()
loss_fct
=
MSELoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
),
labels
.
view
(
-
1
))
if
self
.
num_labels
==
1
:
else
:
loss
=
loss_fct
(
logits
.
squeeze
(),
labels
.
squeeze
())
else
:
loss
=
loss_fct
(
logits
,
labels
)
elif
self
.
config
.
problem_type
==
"single_label_classification"
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
elif
self
.
config
.
problem_type
==
"multi_label_classification"
:
loss_fct
=
BCEWithLogitsLoss
()
loss
=
loss_fct
(
logits
,
labels
)
if
not
return_dict
:
if
not
return_dict
:
output
=
(
logits
,)
+
outputs
[
1
:]
output
=
(
logits
,)
+
outputs
[
2
:]
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
return
SequenceClassifierOutput
(
return
SequenceClassifierOutput
(
...
...
tests/test_modeling_swin.py
View file @
d4b3e56d
...
@@ -137,9 +137,11 @@ class SwinModelTester:
...
@@ -137,9 +137,11 @@ class SwinModelTester:
model
.
eval
()
model
.
eval
()
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
num_features
=
int
(
config
.
embed_dim
*
2
**
(
len
(
config
.
depths
)
-
1
))
# since the model we're testing only consists of a single layer, expected_seq_len = number of patches
expected_seq_len
=
(
config
.
image_size
//
config
.
patch_size
)
**
2
expected_dim
=
int
(
config
.
embed_dim
*
2
**
(
len
(
config
.
depths
)
-
1
))
self
.
parent
.
assertEqual
(
result
.
last_hidden_state
.
shape
,
(
self
.
batch_size
,
num_features
))
self
.
parent
.
assertEqual
(
result
.
last_hidden_state
.
shape
,
(
self
.
batch_size
,
expected_seq_len
,
expected_dim
))
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
config
.
num_labels
=
self
.
type_sequence_label_size
config
.
num_labels
=
self
.
type_sequence_label_size
...
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
...
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
expected_shape
=
torch
.
Size
((
1
,
1000
))
expected_shape
=
torch
.
Size
((
1
,
1000
))
self
.
assertEqual
(
outputs
.
logits
.
shape
,
expected_shape
)
self
.
assertEqual
(
outputs
.
logits
.
shape
,
expected_shape
)
expected_slice
=
torch
.
tensor
([
-
0.
2952
,
-
0.
4777
,
0.
2025
]).
to
(
torch_device
)
expected_slice
=
torch
.
tensor
([
-
0.
0948
,
-
0.
6454
,
-
0.
0921
]).
to
(
torch_device
)
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[
0
,
:
3
],
expected_slice
,
atol
=
1e-4
))
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[
0
,
:
3
],
expected_slice
,
atol
=
1e-4
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment