Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
da10de84
Commit
da10de84
authored
Oct 30, 2019
by
Rémi Louf
Browse files
fix bug with padding mask + add corresponding test
parent
3b0d2fa3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
3 deletions
+10
-3
examples/utils_summarization.py
examples/utils_summarization.py
+3
-3
examples/utils_summarization_test.py
examples/utils_summarization_test.py
+7
-0
No files found.
examples/utils_summarization.py
View file @
da10de84
...
...
@@ -127,9 +127,9 @@ def build_lm_labels(sequence, pad_token):
def
build_mask
(
sequence
,
pad_token
):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask
=
sequence
.
clone
(
)
mask
[
mask
!
=
pad_token
]
=
1
mask
[
mask
==
pad_token
]
=
0
mask
=
torch
.
ones_like
(
sequence
)
idx_pad_tokens
=
(
sequence
=
=
pad_token
)
mask
[
idx_
pad_token
s
]
=
0
return
mask
...
...
examples/utils_summarization_test.py
View file @
da10de84
...
...
@@ -116,6 +116,13 @@ class SummarizationDataProcessingTest(unittest.TestCase):
build_mask
(
sequence
,
23
).
numpy
(),
expected
.
numpy
()
)
def
test_build_mask_with_padding_equal_to_one
(
self
):
sequence
=
torch
.
tensor
([
8
,
2
,
3
,
4
,
1
,
1
,
1
])
expected
=
torch
.
tensor
([
1
,
1
,
1
,
1
,
0
,
0
,
0
])
np
.
testing
.
assert_array_equal
(
build_mask
(
sequence
,
1
).
numpy
(),
expected
.
numpy
()
)
def
test_compute_token_type_ids
(
self
):
separator
=
101
batch
=
torch
.
tensor
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment