Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
932543f7
Commit
932543f7
authored
Oct 17, 2019
by
Rémi Louf
Browse files
fix test of truncation function
parent
a67413cc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
35 deletions
+13
-35
examples/run_seq2seq_finetuning_test.py
examples/run_seq2seq_finetuning_test.py
+13
-35
No files found.
examples/run_seq2seq_finetuning_test.py
View file @
932543f7
...
...
@@ -21,43 +21,21 @@ class DataLoaderTest(unittest.TestCase):
def
setUp
(
self
):
self
.
block_size
=
10
def
test_truncate_source_and_target_too_small
(
self
):
""" When the sum of the lengths of the source and target sequences is
smaller than the block size (minus the number of special tokens), skip the example. """
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
5
,
6
]
self
.
assertEqual
(
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
),
None
)
def
test_truncate_sequence_too_small
(
self
):
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence
=
[
1
,
2
,
3
,
4
]
expected_output
=
[
1
,
2
,
3
,
4
,
0
,
0
,
0
,
0
,
0
,
0
]
self
.
assertEqual
(
_fit_to_block_size
(
sequence
,
self
.
block_size
),
expected_output
)
def
test_truncate_source_and_target_fit_exactly
(
self
):
""" When the sum of the lengths of the source and target sequences is
equal to the block size (minus the number of special tokens), return the
sequences unchanged. """
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
5
,
6
,
7
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
src_seq
,
fitted_src
)
self
.
assertListEqual
(
tgt_seq
,
fitted_tgt
)
def
test_truncate_sequence_fit_exactly
(
self
):
sequence
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]
expected_output
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]
self
.
assertEqual
(
_fit_to_block_size
(
sequence
,
self
.
block_size
),
expected_output
)
def
test_truncate_source_too_big_target_ok
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
]
tgt_seq
=
[
1
,
2
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
[
1
,
2
,
3
,
4
,
5
])
self
.
assertListEqual
(
fitted_tgt
,
fitted_tgt
)
def
test_truncate_target_too_big_source_ok
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
1
,
2
,
3
,
4
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
src_seq
)
self
.
assertListEqual
(
fitted_tgt
,
[
1
,
2
,
3
])
def
test_truncate_source_and_target_too_big
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
]
tgt_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
[
1
,
2
,
3
,
4
,
5
])
self
.
assertListEqual
(
fitted_tgt
,
[
1
,
2
])
def
test_truncate_sequence_too_big
(
self
):
sequence
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
]
expected_output
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]
self
.
assertEqual
(
_fit_to_block_size
(
sequence
,
self
.
block_size
),
expected_output
)
def
test_process_story_no_highlights
(
self
):
""" Processing a story with no highlights should raise an exception.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment