Commit 1aec9405 authored by Rémi Louf
Browse files

test the full story processing

parent 22e1af68
...@@ -87,9 +87,9 @@ class TextDataset(Dataset): ...@@ -87,9 +87,9 @@ class TextDataset(Dataset):
path_to_stories = os.path.join(data_dir, dataset, "stories") path_to_stories = os.path.join(data_dir, dataset, "stories")
assert os.path.isdir(path_to_stories) assert os.path.isdir(path_to_stories)
stories_files = os.listdir(path_to_stories) story_filenames_list = os.listdir(path_to_stories)
for story_file in stories_files: for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, "story_file") path_to_story = os.path.join(path_to_stories, story_filename)
if not os.path.isfile(path_to_story): if not os.path.isfile(path_to_story):
continue continue
...@@ -97,16 +97,16 @@ class TextDataset(Dataset): ...@@ -97,16 +97,16 @@ class TextDataset(Dataset):
try: try:
raw_story = source.read() raw_story = source.read()
story, summary = process_story(raw_story) story, summary = process_story(raw_story)
except IndexError: except IndexError: # skip ill-formed stories
continue continue
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size) story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
example = tokenizer.add_special_token_sequence_pair(
story_seq, summary_seq self.examples.append(
tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
) )
self.examples.append(example)
logger.info("Saving features into cache file %s", cached_features_file) logger.info("Saving features into cache file %s", cached_features_file)
with open(cached_features_file, "wb") as sink: with open(cached_features_file, "wb") as sink:
...@@ -120,8 +120,13 @@ class TextDataset(Dataset): ...@@ -120,8 +120,13 @@ class TextDataset(Dataset):
def process_story(raw_story): def process_story(raw_story):
""" Process the text contained in a story file. """ Extract the story and summary from a story file.
Returns the story and the summary
Attributes:
raw_story (str): content of the story file as a UTF-8 encoded string.
Raises:
IndexError: If the story is empty or contains no highlights.
""" """
file_lines = list( file_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]) filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
...@@ -158,7 +163,7 @@ def _add_missing_period(line): ...@@ -158,7 +163,7 @@ def _add_missing_period(line):
return line return line
if line[-1] in END_TOKENS: if line[-1] in END_TOKENS:
return line return line
return line + " ." return line + "."
def _fit_to_block_size(src_sequence, tgt_sequence, block_size): def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
...@@ -169,6 +174,13 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size): ...@@ -169,6 +174,13 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
block size of 512 this means limiting the source sequence's length to 384 block size of 512 this means limiting the source sequence's length to 384
and the target sequence's length to 128. and the target sequence's length to 128.
Attributes:
src_sequence (list): a list of ids that maps to the tokens of the
source sequence.
tgt_sequence (list): a list of ids that maps to the tokens of the
target sequence.
block_size (int): the model's block size.
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019). Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
""" """
......
...@@ -14,21 +14,21 @@ ...@@ -14,21 +14,21 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from run_seq2seq_finetuning import _fit_to_block_size from run_seq2seq_finetuning import _fit_to_block_size, process_story
class DataLoaderTest(unittest.TestCase): class DataLoaderTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.block_size = 10 self.block_size = 10
def test_source_and_target_too_small(self): def test_truncate_source_and_target_too_small(self):
""" When the sum of the lengths of the source and target sequences is """ When the sum of the lengths of the source and target sequences is
smaller than the block size (minus the number of special tokens), skip the example. """ smaller than the block size (minus the number of special tokens), skip the example. """
src_seq = [1, 2, 3, 4] src_seq = [1, 2, 3, 4]
tgt_seq = [5, 6] tgt_seq = [5, 6]
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None) self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
def test_source_and_target_fit_exactly(self): def test_truncate_source_and_target_fit_exactly(self):
""" When the sum of the lengths of the source and target sequences is """ When the sum of the lengths of the source and target sequences is
equal to the block size (minus the number of special tokens), return the equal to the block size (minus the number of special tokens), return the
sequences unchanged. """ sequences unchanged. """
...@@ -38,27 +38,61 @@ class DataLoaderTest(unittest.TestCase): ...@@ -38,27 +38,61 @@ class DataLoaderTest(unittest.TestCase):
self.assertListEqual(src_seq, fitted_src) self.assertListEqual(src_seq, fitted_src)
self.assertListEqual(tgt_seq, fitted_tgt) self.assertListEqual(tgt_seq, fitted_tgt)
def test_truncate_source_too_big_target_ok(self):
    """ When only the source sequence is too long, the source is expected
    to be truncated while the target sequence is returned unchanged. """
    src_seq = [1, 2, 3, 4, 5, 6]
    tgt_seq = [1, 2]
    fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
    self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
    # Compare against the input target: asserting `fitted_tgt` against
    # itself always passes and would hide a wrongly truncated target.
    self.assertListEqual(fitted_tgt, tgt_seq)
def test_target_too_big_source_ok(self): def test_truncate_target_too_big_source_ok(self):
src_seq = [1, 2, 3, 4] src_seq = [1, 2, 3, 4]
tgt_seq = [1, 2, 3, 4] tgt_seq = [1, 2, 3, 4]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(fitted_src, src_seq) self.assertListEqual(fitted_src, src_seq)
self.assertListEqual(fitted_tgt, [1, 2, 3]) self.assertListEqual(fitted_tgt, [1, 2, 3])
def test_source_and_target_too_big(self): def test_truncate_source_and_target_too_big(self):
src_seq = [1, 2, 3, 4, 5, 6, 7] src_seq = [1, 2, 3, 4, 5, 6, 7]
tgt_seq = [1, 2, 3, 4, 5, 6, 7] tgt_seq = [1, 2, 3, 4, 5, 6, 7]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5]) self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
self.assertListEqual(fitted_tgt, [1, 2]) self.assertListEqual(fitted_tgt, [1, 2])
def test_process_story_no_highlights(self):
    """ Processing a story that contains no highlights is expected to raise
    an IndexError. """
    # The literal is reproduced exactly as written; its continuation lines
    # are part of the string and therefore stay flush left.
    story_without_highlights = """It was the year of Our Lord one thousand seven hundred and
seventy-five.\n\nSpiritual revelations were conceded to England at that
favoured period, as at this."""
    self.assertRaises(IndexError, process_story, story_without_highlights)
def test_process_empty_story(self):
    """ An empty story should also raise an exception (IndexError). """
    empty_story = ""
    self.assertRaises(IndexError, process_story, empty_story)
def test_story_with_missing_period(self):
    """ Lines that lack a final period are expected to come back with one
    appended, in the story as well as in the summary. """
    # Neither the first story line nor the highlight ends with a period.
    content = (
        "It was the year of Our Lord one thousand seven hundred and "
        "seventy-five\n\nSpiritual revelations were conceded to England "
        "at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
    )
    expected_story = (
        "It was the year of Our Lord one thousand seven hundred and "
        "seventy-five. Spiritual revelations were conceded to England at that "
        "favoured period, as at this."
    )
    expected_summary = "It was the best of times."

    story, summary = process_story(content)

    self.assertEqual(expected_story, story)
    self.assertEqual(expected_summary, summary)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment