Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
932543f7
Commit
932543f7
authored
Oct 17, 2019
by
Rémi Louf
Browse files
fix test of truncation function
parent
a67413cc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
35 deletions
+13
-35
examples/run_seq2seq_finetuning_test.py
examples/run_seq2seq_finetuning_test.py
+13
-35
No files found.
examples/run_seq2seq_finetuning_test.py
View file @
932543f7
...
...
@@ -21,43 +21,21 @@ class DataLoaderTest(unittest.TestCase):
def
setUp
(
self
):
self
.
block_size
=
10
def
test_truncate_source_and_target_too_small
(
self
):
""" When the sum of the lengths of the source and target sequences is
smaller than the block size (minus the number of special tokens), skip the example. """
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
5
,
6
]
self
.
assertEqual
(
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
),
None
)
def
test_truncate_sequence_too_small
(
self
):
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence
=
[
1
,
2
,
3
,
4
]
expected_output
=
[
1
,
2
,
3
,
4
,
0
,
0
,
0
,
0
,
0
,
0
]
self
.
assertEqual
(
_fit_to_block_size
(
sequence
,
self
.
block_size
),
expected_output
)
def
test_truncate_source_and_target_fit_exactly
(
self
):
""" When the sum of the lengths of the source and target sequences is
equal to the block size (minus the number of special tokens), return the
sequences unchanged. """
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
5
,
6
,
7
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
src_seq
,
fitted_src
)
self
.
assertListEqual
(
tgt_seq
,
fitted_tgt
)
def
test_truncate_sequence_fit_exactly
(
self
):
sequence
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]
expected_output
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]
self
.
assertEqual
(
_fit_to_block_size
(
sequence
,
self
.
block_size
),
expected_output
)
def
test_truncate_source_too_big_target_ok
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
]
tgt_seq
=
[
1
,
2
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
[
1
,
2
,
3
,
4
,
5
])
self
.
assertListEqual
(
fitted_tgt
,
fitted_tgt
)
def
test_truncate_target_too_big_source_ok
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
1
,
2
,
3
,
4
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
src_seq
)
self
.
assertListEqual
(
fitted_tgt
,
[
1
,
2
,
3
])
def
test_truncate_source_and_target_too_big
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
]
tgt_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
[
1
,
2
,
3
,
4
,
5
])
self
.
assertListEqual
(
fitted_tgt
,
[
1
,
2
])
def
test_truncate_sequence_too_big
(
self
):
sequence
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
]
expected_output
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]
self
.
assertEqual
(
_fit_to_block_size
(
sequence
,
self
.
block_size
),
expected_output
)
def
test_process_story_no_highlights
(
self
):
""" Processing a story with no highlights should raise an exception.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment