"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c6cf925ff8f06d05e0a6a281db51b91b27b91e69"
Commit 22e1af68 authored by Rémi Louf

truncation function is fully tested

parent 260ac7d9
@@ -41,7 +41,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

-from transformers import BertConfig, Bert2Rnd, BertTokenizer
+from transformers import BertTokenizer

 logger = logging.getLogger(__name__)
@@ -57,19 +57,23 @@ class TextDataset(Dataset):
     CNN/Daily News:

     The CNN/Daily News raw datasets are downloaded from [1]. The stories are
     stored in different files; the summary appears at the end of the story as
     sentences that are prefixed by the special `@highlight` line. To process
     the data, untar both datasets in the same folder, and pass the path to this
     folder as the "data_dir" argument. The formatting code was inspired by [2].

     [1] https://cs.nyu.edu/~kcho/
     [2] https://github.com/abisee/cnn-dailymail/
     """

-    def __init_(self, tokenizer, data_dir='', block_size=512):
+    def __init__(self, tokenizer, data_dir="", block_size=512):
         assert os.path.isdir(data_dir)

         # Load features that have already been computed if present
-        cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir))
+        cached_features_file = os.path.join(
+            data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
+        )
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
             with open(cached_features_file, "rb") as source:
@@ -78,7 +82,7 @@ class TextDataset(Dataset):
             logger.info("Creating features from dataset at %s", data_dir)
-            datasets = ['cnn', 'dailymail']
+            datasets = ["cnn", "dailymail"]
             for dataset in datasets:
                 path_to_stories = os.path.join(data_dir, dataset, "stories")
                 assert os.path.isdir(path_to_stories)
@@ -99,7 +103,9 @@ class TextDataset(Dataset):
                 story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
                 summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
                 story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
-                example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
+                example = tokenizer.add_special_token_sequence_pair(
+                    story_seq, summary_seq
+                )
                 self.examples.append(example)

         logger.info("Saving features into cache file %s", cached_features_file)
@@ -117,7 +123,9 @@ def process_story(raw_story):
     """ Process the text contained in a story file.
     Returns the story and the summary
     """
-    file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
+    file_lines = list(
+        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
+    )

     # for some unknown reason some lines miss a period, add it
     file_lines = [_add_missing_period(line) for line in file_lines]
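
To illustrate the `@highlight` convention that process_story relies on, here is a made-up story string; the function is expected to return the story text and the summary (the highlighted sentences) separately, though the exact return types are not visible in this hunk. A sketch, assuming the script is importable:

    from run_seq2seq_finetuning import process_story

    raw_story = (
        "The quick brown fox jumped over the lazy dog.\n"
        "\n"
        "Witnesses described the scene in detail\n"  # _add_missing_period will append the missing period
        "\n"
        "@highlight\n"
        "\n"
        "Fox jumps over dog\n"
    )
    story, summary = process_story(raw_story)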
@@ -145,7 +153,7 @@ def process_story(raw_story):
 def _add_missing_period(line):
-    END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
+    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
     if line.startswith("@highlight"):
         return line
     if line[-1] in END_TOKENS:
@@ -154,34 +162,35 @@ def _add_missing_period(line):
 def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
-    """ Concatenate the sequences and adapt their lengths to the block size.
-    Following [1] we truncate the source and target + tokens sequences so they fit
-    in the block size. If the concatenated sequence is longer than 512 we follow
-    the 75%/25% rule in [1]: limit the source sequence's length to 384 and the
-    target sequence's length to 128.
+    """ Adapt the source and target sequences' lengths to the block size.
+    If the concatenated sequence (source + target + 3 special tokens) would be
+    longer than the block size we use the 75% / 25% rule followed in [1]. For a
+    block size of 512 this means limiting the source sequence's length to 384
+    and the target sequence's length to 128.

     [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
     Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
     """
     SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
-    TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # EOS token
+    TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1  # EOS token

     # We dump the examples that are too small to fit in the block size for the
     # sake of simplicity. You can modify this by adding model-specific padding.
-    if len(src_sequence) + len(src_sequence) + 3 < block_size:
+    if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
         return None

-    # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
     if len(src_sequence) > SRC_MAX_LENGTH:
         if len(tgt_sequence) > TGT_MAX_LENGTH:
             src_sequence = src_sequence[:SRC_MAX_LENGTH]
             tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
         else:
-            src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
+            remain_size = block_size - len(tgt_sequence) - 3
+            src_sequence = src_sequence[:remain_size]
     else:
         if len(tgt_sequence) > TGT_MAX_LENGTH:
-            tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
+            remain_size = block_size - len(src_sequence) - 3
+            tgt_sequence = tgt_sequence[:remain_size]

     return src_sequence, tgt_sequence
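
To make the budget arithmetic concrete: with block_size=512, SRC_MAX_LENGTH is int(0.75 * 512) - 2 = 382 source tokens, which together with the CLS and EOS tokens fills 384 positions (75% of the block), and TGT_MAX_LENGTH is 512 - 384 - 1 = 127 target tokens, or 128 positions with the closing token (25%). The behaviour is easiest to see with a toy block size that mirrors the unit tests further down; a sketch, assuming the function is importable:

    from run_seq2seq_finetuning import _fit_to_block_size

    # block_size=10 gives SRC_MAX_LENGTH = 5 and TGT_MAX_LENGTH = 2
    src = [1, 2, 3, 4, 5, 6, 7]
    tgt = [1, 2, 3, 4, 5, 6, 7]
    print(_fit_to_block_size(src, tgt, 10))  # ([1, 2, 3, 4, 5], [1, 2])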
@@ -200,44 +209,50 @@ def main():
     parser = argparse.ArgumentParser()

     # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
         type=str,
         required=True,
         help="The input training data file (a text file).",
     )
     parser.add_argument(
         "--output_dir",
         default=None,
         type=str,
         required=True,
         help="The output directory where the model predictions and checkpoints will be written.",
     )

     # Optional parameters
     parser.add_argument(
         "--model_name_or_path",
         default="bert-base-cased",
         type=str,
         help="The model checkpoint for weights initialization.",
     )
     parser.add_argument("--seed", default=42, type=int)
     args = parser.parse_args()

     # Set up training device
-    device = torch.device("cpu")
+    # device = torch.device("cpu")

     # Set seed
     set_seed(args)

     # Load pretrained model and tokenizer
-    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
-    config = config_class.from_pretrained(args.model_name_or_path)
+    tokenizer_class = BertTokenizer
+    # config = config_class.from_pretrained(args.model_name_or_path)
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
-    model = model_class.from_pretrained(args.model_name_or_path, config=config)
-    model.to(device)
+    # model = model_class.from_pretrained(args.model_name_or_path, config=config)
+    # model.to(device)

     logger.info("Training/evaluation parameters %s", args)

     # Training
-    train_dataset = load_and_cache_examples(args, tokenizer)
-    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
-    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+    _ = load_and_cache_examples(args, tokenizer)
+    # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+    # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


 if __name__ == "__main__":
@@ -14,50 +14,50 @@
 # limitations under the License.

 import unittest

-from .run_seq2seq_finetuning import process_story, _fit_to_block_size
+from run_seq2seq_finetuning import _fit_to_block_size


 class DataLoaderTest(unittest.TestCase):
-    def __init__(self, block_size=10):
-        self.block_size = block_size
+    def setUp(self):
+        self.block_size = 10

-    def source_and_target_too_small(self):
+    def test_source_and_target_too_small(self):
         """ When the sum of the lengths of the source and target sequences is
         smaller than the block size (minus the number of special tokens), skip the example. """
         src_seq = [1, 2, 3, 4]
         tgt_seq = [5, 6]
         self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)

-    def source_and_target_fit_exactly(self):
+    def test_source_and_target_fit_exactly(self):
         """ When the sum of the lengths of the source and target sequences is
         equal to the block size (minus the number of special tokens), return the
         sequences unchanged. """
         src_seq = [1, 2, 3, 4]
         tgt_seq = [5, 6, 7]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == fitted_src)
-        self.assertListEqual(tgt_seq == fitted_tgt)
+        self.assertListEqual(src_seq, fitted_src)
+        self.assertListEqual(tgt_seq, fitted_tgt)

-    def source_too_big_target_ok(self):
+    def test_source_too_big_target_ok(self):
         src_seq = [1, 2, 3, 4, 5, 6]
         tgt_seq = [1, 2]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
-        self.assertListEqual(tgt_seq == fitted_tgt)
+        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
+        self.assertListEqual(fitted_tgt, tgt_seq)

-    def target_too_big_source_ok(self):
+    def test_target_too_big_source_ok(self):
         src_seq = [1, 2, 3, 4]
         tgt_seq = [1, 2, 3, 4]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == src_seq)
-        self.assertListEqual(tgt_seq == [1, 2, 3])
+        self.assertListEqual(fitted_src, src_seq)
+        self.assertListEqual(fitted_tgt, [1, 2, 3])

-    def source_and_target_too_big(self):
+    def test_source_and_target_too_big(self):
         src_seq = [1, 2, 3, 4, 5, 6, 7]
         tgt_seq = [1, 2, 3, 4, 5, 6, 7]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
-        self.assertListEqual(tgt_seq == [1, 2])
+        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
+        self.assertListEqual(fitted_tgt, [1, 2])


 if __name__ == "__main__":
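
Since the test file now imports run_seq2seq_finetuning with an absolute import, the tests are meant to be run from the directory that contains that script, for example with the standard library runner (the test module name below is an assumption):

    python -m unittest test_run_seq2seq_finetuning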