chenpangpang / transformers / Commits / 3b4d3d09

Unverified commit 3b4d3d09, authored Jun 06, 2024 by Alex Gorodnitskiy, committed by GitHub on Jun 06, 2024.

Fix SwinLayer / DonutSwinLayer / ClapAudioLayer attention mask device (#31295)

Fix DonutSwinLayer attention mask device
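The change is identical in all three modeling files: the shifted-window attention mask used to be allocated on the default (CPU) device and then copied onto the feature tensor's device at the call site; after this commit it is allocated directly on that device. A minimal sketch of the two patterns (illustrative helper names, not the library code; only torch.zeros and Tensor.to are real APIs):

import torch

def make_attn_mask_old(height, width, dtype, target_device):
    # Old pattern: torch.zeros defaults to the CPU device ...
    mask = torch.zeros((1, height, width, 1), dtype=dtype)
    # ... so the caller had to copy the mask over afterwards.
    return mask.to(target_device)

def make_attn_mask_new(height, width, dtype, target_device):
    # New pattern: allocate the mask directly on the device holding the
    # hidden states, so no separate host-to-device copy is needed.
    return torch.zeros((1, height, width, 1), dtype=dtype, device=target_device)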
parent b6c9f47f

Showing 3 changed files with 15 additions and 15 deletions.
src/transformers/models/clap/modeling_clap.py          +5 -5
src/transformers/models/donut/modeling_donut_swin.py   +5 -5
src/transformers/models/swin/modeling_swin.py          +5 -5
src/transformers/models/clap/modeling_clap.py

@@ -593,10 +593,10 @@ class ClapAudioLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -661,9 +661,9 @@ class ClapAudioLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
src/transformers/models/donut/modeling_donut_swin.py

@@ -565,10 +565,10 @@ class DonutSwinLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -633,9 +633,9 @@ class DonutSwinLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
src/transformers/models/swin/modeling_swin.py

@@ -642,10 +642,10 @@ class SwinLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -710,9 +710,9 @@ class SwinLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
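A quick, self-contained way to check the behavior the patch targets. This is a simplified stand-in for get_attn_mask that only mirrors the new signature and the mask allocation, not the repository's code or test suite; the shape values are arbitrary.

import torch

def get_attn_mask_sketch(shift_size, height, width, dtype, device):
    # Simplified stand-in: only the mask allocation, mirroring the new signature.
    if shift_size > 0:
        return torch.zeros((1, height, width, 1), dtype=dtype, device=device)
    return None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mask = get_attn_mask_sketch(shift_size=3, height=56, width=56, dtype=torch.float32, device=device)
if mask is not None:
    # With the fix, the mask is created on the same device as the windowed hidden states.
    assert mask.device.type == device.type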