Commit 260ac7d9 authored by Rémi Louf

wip commit, switching computers

parent fe25eefc
@@ -31,7 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
"""
import argparse
import dequeue
from collections import deque
import logging
import pickle
import random
@@ -57,9 +57,9 @@ class TextDataset(Dataset):
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored
in different files where the summary sentences are indicated by the special `@highlight` token.
To process the data, untar both datasets in the same folder, and pass the path to this
The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process the
data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
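For illustration, a minimal sketch of the raw story layout described above. The text is invented, and the exact strings returned depend on parts of `process_story` that are collapsed in this diff.

# Hypothetical CNN/Daily News story: article sentences first, then the
# summary as sentences each preceded by an "@highlight" line.
raw_story = "\n".join([
    "(CNN) -- First sentence of the article body.",
    "Second sentence of the article body.",
    "",
    "@highlight",
    "First summary sentence",
    "@highlight",
    "Second summary sentence",
])

# Roughly, process_story(raw_story) is expected to return the article
# sentences joined as `story` and the @highlight sentences as `summary`.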
@@ -69,7 +69,7 @@ class TextDataset(Dataset):
assert os.path.isdir(data_dir)
# Load features that have already been computed if present
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir)
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir))
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, "rb") as source:
@@ -86,18 +86,19 @@ class TextDataset(Dataset):
stories_files = os.listdir(path_to_stories)
for story_file in stories_files:
path_to_story = os.path.join(path_to_stories, story_file)
if !os.path.isfile(path_to_story):
if not os.path.isfile(path_to_story):
continue
with open(path_to_story, encoding="utf-8") as source:
try:
story, summary = process_story(source)
raw_story = source.read()
story, summary = process_story(raw_story)
except IndexError:
continue
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
self.examples.append(example)
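As a rough sketch of the tokenization step above. The tokenizer class and checkpoint are assumptions for illustration; the tokenizer the script actually instantiates is not part of this diff.

from transformers import BertTokenizer  # `pytorch_transformers` in older versions of the library

# Illustrative checkpoint only; the script's actual tokenizer is not shown here.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

tokens = tokenizer.tokenize("The cat sat on the mat.")
ids = tokenizer.convert_tokens_to_ids(tokens)
# tokens -> ['the', 'cat', 'sat', 'on', 'the', 'mat', '.']
# ids    -> the corresponding vocabulary ids, which is what _fit_to_block_size receives.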
@@ -108,22 +109,22 @@ class TextDataset(Dataset):
def __len__(self):
return len(self.examples)
def __getitem__(self):
def __getitem__(self, items):
return torch.tensor(self.examples[items])
def process_story(story_file):
def process_story(raw_story):
""" Process the text contained in a story file.
Returns the story and the summary
"""
file_lines = list(filter(lambda x: len(x)!=0, [line.strip() for lines in story_file]))
file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
# for some unknown reason some lines miss a period, add it
file_lines = [_add_missing_period(line) for line in file_lines]
# gather article lines
story_lines = []
lines = dequeue(file_lines)
lines = deque(file_lines)
while True:
try:
element = lines.popleft()
@@ -134,7 +135,7 @@ def process_story(story_file):
raise ie
# gather summary lines
highlights_lines = list(filter(lambda t: !t.startswith("@highlight"), lines))
highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
# join the lines
story = " ".join(story_lines)
@@ -145,7 +146,7 @@ def process_story(story_file):
def _add_missing_period(line):
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u201d', ")"]
if line == "@highlight":
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
return line
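A small illustration of `_add_missing_period`. Its final return statement is collapsed in this diff, so the behaviour for lines without terminal punctuation (presumably appending a period) is an assumption.

# Hedged examples; only the two early returns are visible in this hunk.
_add_missing_period("@highlight")            # returned unchanged
_add_missing_period("A complete sentence.")  # returned unchanged, already ends with an END_TOKEN
_add_missing_period("A truncated line")      # presumably returns "A truncated line."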
@@ -163,8 +164,8 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
"""
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
# we dump the examples that are too small to fit in the block size for the
# sake of simplicity. You can modify this by adding model-specific padding.
@@ -172,22 +173,21 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
return None
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
if len(src_sequence) > SRC_MAX_LENGTH
if len(src_sequence) > SRC_MAX_LENGTH:
if len(tgt_sequence) > TGT_MAX_LENGTH:
src_sequence = src_sequence[:SRC_MAX_LENGTH]
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
else:
src_sequence = src_sequence[:block_size - len(tgt_sequence) - 3]
else:
if len(tgt_tokens) > TGT_MAX_LENGTH:
if len(tgt_sequence) > TGT_MAX_LENGTH:
tgt_sequence = tgt_sequence[:block_size - len(src_sequence) - 3]
return src_sequence, tgt_sequence
def load_and_cache_examples(args, tokenizer):
dataset = TextDataset(tokenizer, file_path=args.train_data_file)
dataset = TextDataset(tokenizer, file_path=args.data_dir)
return dataset
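To make the source/target budget in `_fit_to_block_size` concrete, here is the arithmetic for an assumed block_size of 512; the value the script actually uses is not visible in this diff.

block_size = 512  # illustrative value only

SRC_MAX_LENGTH = int(0.75 * block_size) - 2       # 382 tokens left for the story (CLS and EOS)
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # 129 tokens left for the summary (EOS)

# 382 + 129 + 1 == 512: story, summary and one separating special token fill the block.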
@@ -200,7 +200,7 @@ def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--train_data_file",
parser.add_argument("--data_dir",
default=None,
type=str,
required=True,
......
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from .run_seq2seq_finetuning import process_story, _fit_to_block_size
class DataLoaderTest(unittest.TestCase):
def setUp(self):
self.block_size = 10
def test_source_and_target_too_small(self):
""" When the sum of the lengths of the source and target sequences is
smaller than the block size (minus the number of special tokens), skip the example. """
src_seq = [1, 2, 3, 4]
tgt_seq = [5, 6]
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
def test_source_and_target_fit_exactly(self):
""" When the sum of the lengths of the source and target sequences is
equal to the block size (minus the number of special tokens), return the
sequences unchanged. """
src_seq = [1, 2, 3, 4]
tgt_seq = [5, 6, 7]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(fitted_src, src_seq)
self.assertListEqual(fitted_tgt, tgt_seq)
def test_source_too_big_target_ok(self):
src_seq = [1, 2, 3, 4, 5, 6]
tgt_seq = [1, 2]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
self.assertListEqual(fitted_tgt, tgt_seq)
def test_target_too_big_source_ok(self):
src_seq = [1, 2, 3, 4]
tgt_seq = [1, 2, 3, 4]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(fitted_src, src_seq)
self.assertListEqual(fitted_tgt, [1, 2, 3])
def test_source_and_target_too_big(self):
src_seq = [1, 2, 3, 4, 5, 6, 7]
tgt_seq = [1, 2, 3, 4, 5, 6, 7]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
self.assertListEqual(fitted_tgt, [1, 2])
if __name__ == "__main__":
unittest.main()
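The new tests can be run with the standard unittest machinery; the module name below is hypothetical since the test file's path is not visible in this rendering.

import unittest

# Hypothetical module name; adjust to wherever the test file above is saved.
suite = unittest.defaultTestLoader.loadTestsFromName("test_run_seq2seq_finetuning.DataLoaderTest")
unittest.TextTestRunner(verbosity=2).run(suite)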