Commit 260ac7d9 authored by Rémi Louf

wip commit, switching computers

parent fe25eefc
@@ -31,7 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
 """
 import argparse
-import dequeue
+from collections import deque
 import logging
 import pickle
 import random
@@ -57,9 +57,9 @@ class TextDataset(Dataset):
     CNN/Daily News:

-    The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored
-    in different files where the summary sentences are indicated by the special `@highlight` token.
-    To process the data, untar both datasets in the same folder, and pass the path to this
+    The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as
+    sentences that are prefixed by the special `@highlight` line. To process the
+    data, untar both datasets in the same folder, and pass the path to this
     folder as the "data_dir" argument. The formatting code was inspired by [2].

     [1] https://cs.nyu.edu/~kcho/
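For orientation: a CNN/Daily News story file is plain text with the article body first and the summary sentences after it, each one introduced by a line reading `@highlight`. A toy illustration of the format and of what `process_story` is meant to return (invented text; the exact joining of the summary is an assumption, since part of the function is elided in the hunks below):

raw_story = (
    "The first sentence of the article.\n"
    "The second sentence of the article.\n"
    "@highlight\n"
    "First summary point\n"
    "@highlight\n"
    "Second summary point"
)
story, summary = process_story(raw_story)
# story   -> "The first sentence of the article. The second sentence of the article."
# summary -> the two highlight sentences, with any missing periods added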
@@ -69,7 +69,7 @@ class TextDataset(Dataset):
         assert os.path.isdir(data_dir)

         # Load features that have already been computed if present
-        cached_features_file = os.path.join(data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
+        cached_features_file = os.path.join(data_dir, "cached_lm_{}_{}".format(block_size, data_dir))
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
             with open(cached_features_file, "rb") as source:
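The hunk above only shows the cache-hit path. For context, the load-or-build pattern it implies looks roughly like the sketch below; `build_examples` and the exact cache-file name are assumptions, not code from this commit:

import os
import pickle

def load_or_build(data_dir, block_size, build_examples):
    # Cache keyed on the block size, mirroring the naming scheme above (assumed).
    cached_features_file = os.path.join(data_dir, "cached_lm_{}".format(block_size))
    if os.path.exists(cached_features_file):
        with open(cached_features_file, "rb") as source:
            return pickle.load(source)  # cache hit: reuse the computed features
    examples = build_examples()         # cache miss: compute features from the stories
    with open(cached_features_file, "wb") as sink:
        pickle.dump(examples, sink)     # persist for the next run
    return examples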
@@ -86,18 +86,19 @@ class TextDataset(Dataset):
             stories_files = os.listdir(path_to_stories)
             for story_file in stories_files:
                 path_to_story = os.path.join(path_to_stories, story_file)
-                if !os.path.isfile(path_to_story):
+                if not os.path.isfile(path_to_story):
                     continue
                 with open(path_to_story, encoding="utf-8") as source:
                     try:
-                        story, summary = process_story(source)
+                        raw_story = source.read()
+                        story, summary = process_story(raw_story)
                     except IndexError:
                         continue

                 story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
                 summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-                story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
+                story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
                 example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
                 self.examples.append(example)
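One detail the loop above does not handle yet: `_fit_to_block_size` (further down) returns `None` for examples that are too short, and unpacking `None` into two names raises a `TypeError`. A defensive variant of that call site, as a sketch rather than part of the commit, would be:

fitted = _fit_to_block_size(story, summary, block_size)
if fitted is None:
    continue  # example too short for the block size; skip it
story_seq, summary_seq = fitted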
@@ -108,22 +109,22 @@ class TextDataset(Dataset):
     def __len__(self):
         return len(self.examples)

-    def __getitem__(self):
+    def __getitem__(self, items):
         return torch.tensor(self.examples[items])


-def process_story(story_file):
+def process_story(raw_story):
     """ Process the text contained in a story file.
     Returns the story and the summary
     """
-    file_lines = list(filter(lambda x: len(x)!=0, [line.strip() for lines in story_file]))
+    file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))

     # for some unknown reason some lines are missing a period; add it
     file_lines = [_add_missing_period(line) for line in file_lines]

     # gather article lines
     story_lines = []
-    lines = dequeue(file_lines)
+    lines = deque(file_lines)
     while True:
         try:
             element = lines.popleft()
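The middle of this `while` loop is collapsed by the diff. Judging from the `popleft()` call here, the `raise ie` that reappears in the next hunk, and the `except IndexError: continue` at the call site, a plausible reconstruction (an assumption, not the commit's exact code) is:

while True:
    try:
        element = lines.popleft()
        if element.startswith("@highlight"):
            break  # article body finished; the rest of the deque is summary material
        story_lines.append(element)
    except IndexError as ie:
        # story file without any @highlight marker; the caller catches this and skips it
        raise ie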
@@ -134,7 +135,7 @@ def process_story(raw_story):
             raise ie

     # gather summary lines
-    highlights_lines = list(filter(lambda t: !t.startswith("@highlight"), lines))
+    highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

     # join the lines
     story = " ".join(story_lines)
@@ -145,7 +146,7 @@ def process_story(raw_story):
 def _add_missing_period(line):
     END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u201d', ")"]
-    if line == "@highlight":
+    if line.startswith("@highlight"):
         return line
     if line[-1] in END_TOKENS:
         return line
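The tail of `_add_missing_period` is cut off by the diff; assuming it appends a period to any line that ends with none of the `END_TOKENS` (the behaviour the comment in `process_story` describes), it would act like this:

_add_missing_period("@highlight")     # returned as-is: marker lines are left alone
_add_missing_period("He went home.")  # returned as-is: already ends with '.'
_add_missing_period("He went home")   # -> "He went home." (assumed: period appended)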
@@ -163,8 +164,8 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
     [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
         Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
     """
-    SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
-    TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
+    SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
+    TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # EOS token

     # we dump the examples that are too small to fit in the block size for the
     # sake of simplicity. You can modify this by adding model-specific padding.
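Concretely, with the usual BERT block size of 512 the budget above works out to 382 source tokens and 129 target tokens:

block_size = 512
SRC_MAX_LENGTH = int(0.75 * block_size) - 2       # 384 - 2 = 382
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # 512 - 382 - 1 = 129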
@@ -172,22 +173,21 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
         return None

     # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
-    if len(src_sequence) > SRC_MAX_LENGTH
+    if len(src_sequence) > SRC_MAX_LENGTH:
         if len(tgt_sequence) > TGT_MAX_LENGTH:
             src_sequence = src_sequence[:SRC_MAX_LENGTH]
             tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
         else:
             src_sequence = src_sequence[:block_size - len(tgt_sequence) - 3]
     else:
-        if len(tgt_tokens) > TGT_MAX_LENGTH:
+        if len(tgt_sequence) > TGT_MAX_LENGTH:
             tgt_sequence = tgt_sequence[:block_size - len(src_sequence) - 3]

     return src_sequence, tgt_sequence


 def load_and_cache_examples(args, tokenizer):
-    dataset = TextDataset(tokenizer, file_path=args.train_data_file)
+    dataset = TextDataset(tokenizer, file_path=args.data_dir)
     return dataset
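For instance, with `block_size=10` (so `SRC_MAX_LENGTH=5` and `TGT_MAX_LENGTH=4`, the values the unit tests below exercise), a long source against a short target is cut down to whatever room the target leaves:

_fit_to_block_size([1, 2, 3, 4, 5, 6], [1, 2], 10)
# -> ([1, 2, 3, 4, 5], [1, 2])   i.e. src_sequence[:block_size - len(tgt_sequence) - 3]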
@@ -200,7 +200,7 @@ def main():
     parser = argparse.ArgumentParser()

     # Required parameters
-    parser.add_argument("--train_data_file",
+    parser.add_argument("--data_dir",
                         default=None,
                         type=str,
                         required=True,
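Assuming the script is saved as run_seq2seq_finetuning.py (the module name imported by the test file) and that the path below is a placeholder for wherever the untarred CNN and Daily Mail stories live, a minimal invocation would be:

python run_seq2seq_finetuning.py --data_dir /path/to/stories

The commit also adds the new test module reproduced in full below.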
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from .run_seq2seq_finetuning import process_story, _fit_to_block_size
class DataLoaderTest(unittest.TestCase):
    def setUp(self):
        # unittest.TestCase fixtures belong in setUp, not __init__.
        self.block_size = 10

    def test_source_and_target_too_small(self):
        """ When the sum of the lengths of the source and target sequences is
        smaller than the block size (minus the number of special tokens), skip the example. """
        src_seq = [1, 2, 3, 4]
        tgt_seq = [5, 6]
        self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)

    def test_source_and_target_fit_exactly(self):
        """ When the sum of the lengths of the source and target sequences is
        equal to the block size (minus the number of special tokens), return the
        sequences unchanged. """
        src_seq = [1, 2, 3, 4]
        tgt_seq = [5, 6, 7]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, src_seq)
        self.assertListEqual(fitted_tgt, tgt_seq)

    def test_source_too_big_target_ok(self):
        src_seq = [1, 2, 3, 4, 5, 6]
        tgt_seq = [1, 2]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
        self.assertListEqual(fitted_tgt, tgt_seq)

    def test_target_too_big_source_ok(self):
        src_seq = [1, 2, 3, 4]
        tgt_seq = [1, 2, 3, 4]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, src_seq)
        self.assertListEqual(fitted_tgt, [1, 2, 3])

    def test_source_and_target_too_big(self):
        src_seq = [1, 2, 3, 4, 5, 6, 7]
        tgt_seq = [1, 2, 3, 4, 5, 6, 7]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
        self.assertListEqual(fitted_tgt, [1, 2])


if __name__ == "__main__":
    unittest.main()
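Because of the relative import at the top, the test module must be run in package context. Assuming it sits next to the script in a package called examples and is named test_run_seq2seq_finetuning.py (both names are assumptions), the suite can be run from the repository root with:

python -m unittest examples.test_run_seq2seq_finetuning -v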