Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
070507df
Commit
070507df
authored
Oct 30, 2019
by
Rémi Louf
Browse files
format utils for summarization
parent
da10de84
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
7 deletions
+3
-7
examples/utils_summarization.py
examples/utils_summarization.py
+1
-1
examples/utils_summarization_test.py
examples/utils_summarization_test.py
+2
-6
No files found.
examples/utils_summarization.py
View file @
070507df
...
...
@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask
=
torch
.
ones_like
(
sequence
)
idx_pad_tokens
=
(
sequence
==
pad_token
)
idx_pad_tokens
=
sequence
==
pad_token
mask
[
idx_pad_tokens
]
=
0
return
mask
...
...
examples/utils_summarization_test.py
View file @
070507df
...
...
@@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def
test_build_mask_no_padding
(
self
):
sequence
=
torch
.
tensor
([
1
,
2
,
3
,
4
])
expected
=
torch
.
tensor
([
1
,
1
,
1
,
1
])
np
.
testing
.
assert_array_equal
(
build_mask
(
sequence
,
0
).
numpy
(),
expected
.
numpy
()
)
np
.
testing
.
assert_array_equal
(
build_mask
(
sequence
,
0
).
numpy
(),
expected
.
numpy
())
def
test_build_mask
(
self
):
sequence
=
torch
.
tensor
([
1
,
2
,
3
,
4
,
23
,
23
,
23
])
...
...
@@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def
test_build_mask_with_padding_equal_to_one
(
self
):
sequence
=
torch
.
tensor
([
8
,
2
,
3
,
4
,
1
,
1
,
1
])
expected
=
torch
.
tensor
([
1
,
1
,
1
,
1
,
0
,
0
,
0
])
np
.
testing
.
assert_array_equal
(
build_mask
(
sequence
,
1
).
numpy
(),
expected
.
numpy
()
)
np
.
testing
.
assert_array_equal
(
build_mask
(
sequence
,
1
).
numpy
(),
expected
.
numpy
())
def
test_compute_token_type_ids
(
self
):
separator
=
101
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment