Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1aec9405
Commit
1aec9405
authored
Oct 15, 2019
by
Rémi Louf
Browse files
test the full story processing
parent
22e1af68
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
16 deletions
+62
-16
examples/run_seq2seq_finetuning.py
examples/run_seq2seq_finetuning.py
+22
-10
examples/run_seq2seq_finetuning_test.py
examples/run_seq2seq_finetuning_test.py
+40
-6
No files found.
examples/run_seq2seq_finetuning.py
View file @
1aec9405
...
...
@@ -87,9 +87,9 @@ class TextDataset(Dataset):
path_to_stories
=
os
.
path
.
join
(
data_dir
,
dataset
,
"stories"
)
assert
os
.
path
.
isdir
(
path_to_stories
)
stor
ies_files
=
os
.
listdir
(
path_to_stories
)
for
story_file
in
stor
ies_files
:
path_to_story
=
os
.
path
.
join
(
path_to_stories
,
"
story_file
"
)
stor
y_filenames_list
=
os
.
listdir
(
path_to_stories
)
for
story_file
name
in
stor
y_filenames_list
:
path_to_story
=
os
.
path
.
join
(
path_to_stories
,
story_file
name
)
if
not
os
.
path
.
isfile
(
path_to_story
):
continue
...
...
@@ -97,16 +97,16 @@ class TextDataset(Dataset):
try
:
raw_story
=
source
.
read
()
story
,
summary
=
process_story
(
raw_story
)
except
IndexError
:
except
IndexError
:
# skip ill-formed stories
continue
story
=
tokenizer
.
convert_tokens_to_ids
(
tokenizer
.
tokenize
(
story
))
summary
=
tokenizer
.
convert_tokens_to_ids
(
tokenizer
.
tokenize
(
summary
))
story_seq
,
summary_seq
=
_fit_to_block_size
(
story
,
summary
,
block_size
)
example
=
tokenizer
.
add_special_token_sequence_pair
(
story_seq
,
summary_seq
self
.
examples
.
append
(
tokenizer
.
add_special_token_sequence_pair
(
story_seq
,
summary_seq
)
)
self
.
examples
.
append
(
example
)
logger
.
info
(
"Saving features into cache file %s"
,
cached_features_file
)
with
open
(
cached_features_file
,
"wb"
)
as
sink
:
...
...
@@ -120,8 +120,13 @@ class TextDataset(Dataset):
def
process_story
(
raw_story
):
""" Process the text contained in a story file.
Returns the story and the summary
""" Extract the story and summary from a story file.
Attributes:
raw_story (str): content of the story file as an utf-8 encoded string.
Raises:
IndexError: If the stoy is empty or contains no highlights.
"""
file_lines
=
list
(
filter
(
lambda
x
:
len
(
x
)
!=
0
,
[
line
.
strip
()
for
line
in
raw_story
.
split
(
"
\n
"
)])
...
...
@@ -158,7 +163,7 @@ def _add_missing_period(line):
return
line
if
line
[
-
1
]
in
END_TOKENS
:
return
line
return
line
+
"
."
return
line
+
"."
def
_fit_to_block_size
(
src_sequence
,
tgt_sequence
,
block_size
):
...
...
@@ -169,6 +174,13 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
block size of 512 this means limiting the source sequence's length to 384
and the target sequence's length to 128.
Attributes:
src_sequence (list): a list of ids that maps to the tokens of the
source sequence.
tgt_sequence (list): a list of ids that maps to the tokens of the
target sequence.
block_size (int): the model's block size.
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
"""
...
...
examples/run_seq2seq_finetuning_test.py
View file @
1aec9405
...
...
@@ -14,21 +14,21 @@
# limitations under the License.
import
unittest
from
run_seq2seq_finetuning
import
_fit_to_block_size
from
run_seq2seq_finetuning
import
_fit_to_block_size
,
process_story
class
DataLoaderTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
block_size
=
10
def
test_source_and_target_too_small
(
self
):
def
test_
truncate_
source_and_target_too_small
(
self
):
""" When the sum of the lengths of the source and target sequences is
smaller than the block size (minus the number of special tokens), skip the example. """
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
5
,
6
]
self
.
assertEqual
(
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
),
None
)
def
test_source_and_target_fit_exactly
(
self
):
def
test_
truncate_
source_and_target_fit_exactly
(
self
):
""" When the sum of the lengths of the source and target sequences is
equal to the block size (minus the number of special tokens), return the
sequences unchanged. """
...
...
@@ -38,27 +38,61 @@ class DataLoaderTest(unittest.TestCase):
self
.
assertListEqual
(
src_seq
,
fitted_src
)
self
.
assertListEqual
(
tgt_seq
,
fitted_tgt
)
def
test_source_too_big_target_ok
(
self
):
def
test_
truncate_
source_too_big_target_ok
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
]
tgt_seq
=
[
1
,
2
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
[
1
,
2
,
3
,
4
,
5
])
self
.
assertListEqual
(
fitted_tgt
,
fitted_tgt
)
def
test_target_too_big_source_ok
(
self
):
def
test_
truncate_
target_too_big_source_ok
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
]
tgt_seq
=
[
1
,
2
,
3
,
4
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
src_seq
)
self
.
assertListEqual
(
fitted_tgt
,
[
1
,
2
,
3
])
def
test_source_and_target_too_big
(
self
):
def
test_
truncate_
source_and_target_too_big
(
self
):
src_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
]
tgt_seq
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
]
fitted_src
,
fitted_tgt
=
_fit_to_block_size
(
src_seq
,
tgt_seq
,
self
.
block_size
)
self
.
assertListEqual
(
fitted_src
,
[
1
,
2
,
3
,
4
,
5
])
self
.
assertListEqual
(
fitted_tgt
,
[
1
,
2
])
def
test_process_story_no_highlights
(
self
):
""" Processing a story with no highlights should raise an exception.
"""
raw_story
=
"""It was the year of Our Lord one thousand seven hundred and
seventy-five.
\n\n
Spiritual revelations were conceded to England at that
favoured period, as at this."""
with
self
.
assertRaises
(
IndexError
):
process_story
(
raw_story
)
def
test_process_empty_story
(
self
):
""" An empty story should also raise and exception.
"""
raw_story
=
""
with
self
.
assertRaises
(
IndexError
):
process_story
(
raw_story
)
def
test_story_with_missing_period
(
self
):
raw_story
=
(
"It was the year of Our Lord one thousand seven hundred and "
"seventy-five
\n\n
Spiritual revelations were conceded to England "
"at that favoured period, as at this.
\n
@highlight
\n\n
It was the best of times"
)
story
,
summary
=
process_story
(
raw_story
)
expected_story
=
(
"It was the year of Our Lord one thousand seven hundred and "
"seventy-five. Spiritual revelations were conceded to England at that "
"favoured period, as at this."
)
self
.
assertEqual
(
expected_story
,
story
)
expected_summary
=
"It was the best of times."
self
.
assertEqual
(
expected_summary
,
summary
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment