chenpangpang / transformers · Commits

Commit d24097e0 (unverified)
Authored May 21, 2024 by amyeroberts; committed by GitHub on May 21, 2024

Fix swin embeddings interpolation (#30936)

Parent: eae2b6b8
Showing 4 changed files with 10 additions and 62 deletions (+10 −62):
src/transformers/models/donut/modeling_donut_swin.py           +3 −16
src/transformers/models/maskformer/modeling_maskformer_swin.py +3 −16
src/transformers/models/swin/modeling_swin.py                  +2 −15
src/transformers/models/swinv2/modeling_swinv2.py              +2 −15
src/transformers/models/donut/modeling_donut_swin.py

```diff
@@ -205,9 +205,7 @@ class DonutSwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -228,7 +226,7 @@ class DonutSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
 class DonutSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -260,21 +258,10 @@ class DonutSwinPatchEmbeddings(nn.Module):
                 pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
```
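The dropped guards in DonutSwinPatchEmbeddings.forward are the channel-count check and the interpolation-era image-size check; the strict size check appears unnecessary here because maybe_pad already pads any input up to a multiple of the patch size before projection. A minimal sketch of the flow that remains, using an illustrative 4×4 patch size and 96-dim embeddings rather than anything taken from the commit:

```python
import torch
from torch import nn

# illustrative values, not from the commit
patch_size = (4, 4)
projection = nn.Conv2d(3, 96, kernel_size=patch_size, stride=patch_size)

def maybe_pad(pixel_values: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # pad right/bottom so height and width become multiples of the patch size,
    # mirroring the maybe_pad helper kept in the Swin-family patch embeddings
    if width % patch_size[1] != 0:
        pixel_values = nn.functional.pad(pixel_values, (0, patch_size[1] - width % patch_size[1]))
    if height % patch_size[0] != 0:
        pixel_values = nn.functional.pad(pixel_values, (0, 0, 0, patch_size[0] - height % patch_size[0]))
    return pixel_values

pixel_values = torch.randn(1, 3, 238, 253)          # deliberately not a patch multiple
_, _, height, width = pixel_values.shape
padded = maybe_pad(pixel_values, height, width)      # -> (1, 3, 240, 256)
embeddings = projection(padded)                      # -> (1, 96, 60, 64)
output_dimensions = (embeddings.shape[2], embeddings.shape[3])
embeddings = embeddings.flatten(2).transpose(1, 2)   # -> (1, 60 * 64, 96)
print(output_dimensions, embeddings.shape)
```

Because padding, not validation, handles off-grid resolutions, the patch grid reported in output_dimensions simply grows with the input.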
src/transformers/models/maskformer/modeling_maskformer_swin.py

```diff
@@ -197,9 +197,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
     def forward(self, pixel_values, interpolate_pos_encoding):
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
 
         if self.position_embeddings is not None:
@@ -213,7 +211,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->MaskFormerSwin
 class MaskFormerSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -245,21 +243,10 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
                 pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
```
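With the pass-through gone, the patch embeddings stay size-agnostic and position-encoding interpolation remains the job of the embeddings modules (note the `if self.position_embeddings is not None:` branch above). Below is a hedged sketch of the ViT-style resampling such modules typically delegate to; the helper name, the square source grid, and the bicubic mode are assumptions for illustration, not code from this commit:

```python
import torch
from torch import nn

def interpolate_pos_encoding(position_embeddings: torch.Tensor, new_height: int, new_width: int) -> torch.Tensor:
    # position_embeddings: (1, grid * grid, dim), learned on a square patch grid;
    # resample it to a (new_height, new_width) grid so a larger input still gets
    # one position vector per patch
    dim = position_embeddings.shape[-1]
    grid_size = int(position_embeddings.shape[1] ** 0.5)
    pos = position_embeddings.reshape(1, grid_size, grid_size, dim).permute(0, 3, 1, 2)
    pos = nn.functional.interpolate(pos, size=(new_height, new_width), mode="bicubic", align_corners=False)
    return pos.permute(0, 2, 3, 1).reshape(1, new_height * new_width, dim)

# e.g. stretch a 56x56 grid of 96-dim embeddings to a 64x80 grid
pos = torch.randn(1, 56 * 56, 96)
print(interpolate_pos_encoding(pos, 64, 80).shape)  # torch.Size([1, 5120, 96])
```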
src/transformers/models/swin/modeling_swin.py

```diff
@@ -291,9 +291,7 @@ class SwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -345,21 +343,10 @@ class SwinPatchEmbeddings(nn.Module):
                 pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
```
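From the caller's perspective nothing changes: interpolate_pos_encoding is still accepted at the model level, and only the plumbing into SwinPatchEmbeddings is removed. A hedged usage sketch against the public transformers API; the checkpoint name and the 384×384 input are illustrative, not taken from the commit:

```python
import torch
from transformers import SwinModel

# run a Swin checkpoint trained at 224x224 on a larger input, asking the
# embeddings module to interpolate its (optional) absolute position embeddings
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
pixel_values = torch.randn(1, 3, 384, 384)

with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)

print(outputs.last_hidden_state.shape)
```

Note that many Swin checkpoints set use_absolute_embeddings=False, in which case there are no absolute position embeddings to interpolate and the flag is effectively a no-op; padding in the patch embeddings still makes the off-grid input valid.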
src/transformers/models/swinv2/modeling_swinv2.py

```diff
@@ -334,9 +334,7 @@ class Swinv2Embeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -389,21 +387,10 @@ class Swinv2PatchEmbeddings(nn.Module):
                 pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size (
{height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
```
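A pytest-style sketch of the behavior this fix targets, written as a hypothetical regression check rather than copied from the repository's test suite; the checkpoint, input size, and stage arithmetic are assumptions:

```python
import torch
from transformers import Swinv2Model

def test_swinv2_interpolate_pos_encoding_nondefault_size():
    # hypothetical check: with interpolation requested, an input larger than
    # the 256x256 pretraining resolution should run without the old ValueError
    model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
    pixel_values = torch.randn(1, 3, 320, 320)
    with torch.no_grad():
        outputs = model(pixel_values, interpolate_pos_encoding=True)
    # 320 / 4 = 80 patches per side, halved by each of the 3 patch-merging
    # stages in the tiny config: 80 -> 40 -> 20 -> 10, i.e. 10 * 10 tokens
    assert outputs.last_hidden_state.shape[1] == 100
```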