Commit 22e1af68 authored by Rémi Louf's avatar Rémi Louf
Browse files

truncation function is fully tested

parent 260ac7d9
......@@ -41,7 +41,7 @@ import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BertConfig, Bert2Rnd, BertTokenizer
from transformers import BertTokenizer
logger = logging.getLogger(__name__)
......@@ -57,19 +57,23 @@ class TextDataset(Dataset):
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process the
data, untar both datasets in the same folder, and pass the path to this
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
stored in different files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process
the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init_(self, tokenizer, data_dir='', block_size=512):
def __init_(self, tokenizer, data_dir="", block_size=512):
assert os.path.isdir(data_dir)
# Load features that have already been computed if present
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir))
cached_features_file = os.path.join(
data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
)
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, "rb") as source:
......@@ -78,7 +82,7 @@ class TextDataset(Dataset):
logger.info("Creating features from dataset at %s", data_dir)
datasets = ['cnn', 'dailymail']
datasets = ["cnn", "dailymail"]
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
assert os.path.isdir(path_to_stories)
......@@ -99,7 +103,9 @@ class TextDataset(Dataset):
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
example = tokenizer.add_special_token_sequence_pair(
story_seq, summary_seq
)
self.examples.append(example)
logger.info("Saving features into cache file %s", cached_features_file)
......@@ -117,7 +123,9 @@ def process_story(raw_story):
""" Process the text contained in a story file.
Returns the story and the summary
"""
file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
file_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
)
# for some unknown reason some lines miss a period, add it
file_lines = [_add_missing_period(line) for line in file_lines]
......@@ -145,7 +153,7 @@ def process_story(raw_story):
def _add_missing_period(line):
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
......@@ -154,34 +162,35 @@ def _add_missing_period(line):
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
""" Concatenate the sequences and adapt their lengths to the block size.
""" Adapt the source and target sequences' lengths to the block size.
Following [1] we truncate the source and target + tokens sequences so they fit
in the block size. If the concatenated sequence is longer than 512 we follow
the 75%/25% rule in [1]: limit the source sequence's length to 384 and the
target sequence's length to 128.
If the concatenated sequence (source + target + 3 special tokens) would be
longer than the block size we use the 75% / 25% rule followed in [1]. For a
block size of 512 this means limiting the source sequence's length to 384
and the target sequence's length to 128.
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
"""
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1 # EOS token
# we dump the examples that are too small to fit in the block size for the
# We dump the examples that are too small to fit in the block size for the
# sake of simplicity. You can modify this by adding model-specific padding.
if len(src_sequence) + len(src_sequence) + 3 < block_size:
if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
return None
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
if len(src_sequence) > SRC_MAX_LENGTH:
if len(tgt_sequence) > TGT_MAX_LENGTH:
src_sequence = src_sequence[:SRC_MAX_LENGTH]
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
else:
src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
remain_size = block_size - len(tgt_sequence) - 3
src_sequence = src_sequence[:remain_size]
else:
if len(tgt_sequence) > TGT_MAX_LENGTH:
tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
remain_size = block_size - len(src_sequence) - 3
tgt_sequence = tgt_sequence[:remain_size]
return src_sequence, tgt_sequence
......@@ -200,44 +209,50 @@ def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--data_dir",
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input training data file (a text file).")
parser.add_argument("--output_dir",
help="The input training data file (a text file).",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.")
help="The output directory where the model predictions and checkpoints will be written.",
)
# Optional parameters
parser.add_argument("--model_name_or_path",
parser.add_argument(
"--model_name_or_path",
default="bert-base-cased",
type=str,
help="The model checkpoint for weights initialization.")
help="The model checkpoint for weights initialization.",
)
parser.add_argument("--seed", default=42, type=int)
args = parser.parse_args()
# Set up training device
device = torch.device("cpu")
# device = torch.device("cpu")
# Set seed
set_seed(args)
# Load pretrained model and tokenizer
config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
config = config_class.from_pretrained(args.model_name_or_path)
tokenizer_class = BertTokenizer
# config = config_class.from_pretrained(args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path, config=config)
model.to(device)
# model = model_class.from_pretrained(args.model_name_or_path, config=config)
# model.to(device)
logger.info("Training/evaluation parameters %s", args)
# Training
train_dataset = load_and_cache_examples(args, tokenizer)
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
_ = load_and_cache_examples(args, tokenizer)
# global_step, tr_loss = train(args, train_dataset, model, tokenizer)
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
if __name__ == "__main__":
......
......@@ -14,50 +14,50 @@
# limitations under the License.
import unittest
from .run_seq2seq_finetuning import process_story, _fit_to_block_size
from run_seq2seq_finetuning import _fit_to_block_size
class DataLoaderTest(unittest.TestCase):
def __init__(self, block_size=10):
self.block_size = block_size
def setUp(self):
self.block_size = 10
def source_and_target_too_small(self):
def test_source_and_target_too_small(self):
""" When the sum of the lengths of the source and target sequences is
smaller than the block size (minus the number of special tokens), skip the example. """
src_seq = [1, 2, 3, 4]
tgt_seq = [5, 6]
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
def source_and_target_fit_exactly(self):
def test_source_and_target_fit_exactly(self):
""" When the sum of the lengths of the source and target sequences is
equal to the block size (minus the number of special tokens), return the
sequences unchanged. """
src_seq = [1, 2, 3, 4]
tgt_seq = [5, 6, 7]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == fitted_src)
self.assertListEqual(tgt_seq == fitted_tgt)
self.assertListEqual(src_seq, fitted_src)
self.assertListEqual(tgt_seq, fitted_tgt)
def source_too_big_target_ok(self):
def test_source_too_big_target_ok(self):
src_seq = [1, 2, 3, 4, 5, 6]
tgt_seq = [1, 2]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
self.assertListEqual(tgt_seq == fitted_tgt)
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
self.assertListEqual(fitted_tgt, fitted_tgt)
def target_too_big_source_ok(self):
def test_target_too_big_source_ok(self):
src_seq = [1, 2, 3, 4]
tgt_seq = [1, 2, 3, 4]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == src_seq)
self.assertListEqual(tgt_seq == [1, 2, 3])
self.assertListEqual(fitted_src, src_seq)
self.assertListEqual(fitted_tgt, [1, 2, 3])
def source_and_target_too_big(self):
def test_source_and_target_too_big(self):
src_seq = [1, 2, 3, 4, 5, 6, 7]
tgt_seq = [1, 2, 3, 4, 5, 6, 7]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
self.assertListEqual(tgt_seq == [1, 2])
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
self.assertListEqual(fitted_tgt, [1, 2])
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment