Commit 41279327 authored by Rémi Louf

delegate the padding with special tokens to the tokenizer

parent 447fffb2
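
In short, this commit moves the job of wrapping a (story, summary) pair in special tokens from the example script into the tokenizer. A minimal sketch of the before/after, using a stub class as a stand-in for the real tokenizer (the method name add_special_token_sequence_pair is the one that appears in the diff below; everything else here is illustrative):

# Illustrative token lists standing in for real tokenized sequences.
src_tokens = ["the", "cat", "sat"]
tgt_tokens = ["cat", "sits"]

# Before: the example script assembled the final training sequence by hand.
example_manual = ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]

# After: the tokenizer inserts the special tokens itself. This stub only
# mimics the call shown in the diff; it is not the library's real API.
class StubTokenizer:
    def add_special_token_sequence_pair(self, src, tgt):
        return ["[CLS]"] + src + ["[EOS]"] + tgt + ["[EOS]"]

example_delegated = StubTokenizer().add_special_token_sequence_pair(src_tokens, tgt_tokens)
assert example_manual == example_delegated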
@@ -53,20 +53,14 @@ def set_seed(args):
class TextDataset(Dataset):
""" Abstracts a dataset used to train seq2seq models. """ Abstracts the dataset used to train seq2seq models.
A seq2seq dataset consists of two files:
- The source file that contains the source sequences, one line per sequence;
- The target file contains the target sequences, one line per sequence.
The matching betwen source and target sequences is made on the basis of line numbers.
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored
in different files where the summary sentences are indicated by the special `@highlight` token.
-To process the data, untar both datasets in the same folder, and path the path to this
-folder as the "train_data_file" argument. The formatting code was inspired by [2].
+To process the data, untar both datasets in the same folder, and pass the path to this
+folder as the "data_dir" argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
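
For reference, the folder layout this docstring asks for (both archives untarred into the same folder, whose path is passed as data_dir) looks roughly as follows. The cnn/ and dailymail/ names match the datasets = ['cnn', 'dailymail'] loop in the next hunk, and the stories/ subfolders come from the os.path.join(data_dir, dataset, "stories") call; the .story extension is the one used by the raw story dumps:

data_dir/
    cnn/
        stories/          # the raw *.story files
    dailymail/
        stories/          # the raw *.story files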
@@ -82,9 +76,8 @@ class TextDataset(Dataset):
self.examples = pickle.load(source)
return
logger.info("Creating features from dataset at %s", directory) logger.info("Creating features from dataset at %s", data_dir)
# we need to iterate over both the cnn and the dailymail dataset
datasets = ['cnn', 'dailymail']
for dataset in datasets:
    path_to_stories = os.path.join(data_dir, dataset, "stories")
@@ -102,9 +95,10 @@ class TextDataset(Dataset):
except IndexError:
    continue
-src_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
-tgt_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-example = _truncate_and_concatenate(src_sequence, tgt_sequence, blocksize)
+story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
+summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
+story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
+example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
self.examples.append(example)
logger.info("Saving features into cache file %s", cached_features_file)
@@ -158,15 +152,13 @@ def _add_missing_period(line):
return line + " ."
-def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
+def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
""" Concatenate the sequences and adapt their lengths to the block size. """ Concatenate the sequences and adapt their lengths to the block size.
-Following [1] we perform the following transformations:
-- Add an [CLS] token at the beginning of the source sequence;
-- Add an [EOS] token at the end of the source and target sequences;
-- Concatenate the source and target + tokens sequence. If the concatenated sequence is
-longer than 512 we follow the 75%/25% rule in [1]: limit the source sequence's length to 384
-and the target sequence's length to 128.
+Following [1] we truncate the source and target + tokens sequences so they fit
+in the block size. If the concatenated sequence is longer than 512 we follow
+the 75%/25% rule in [1]: limit the source sequence's length to 384 and the
+target sequence's length to 128.
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
@@ -176,22 +168,21 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
# we dump the examples that are too small to fit in the block size for the
# sake of simplicity. You can modify this by adding model-specific padding.
-if len(src_tokens) + len(src_tokens) + 3 < block_size:
+if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
    return None
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
-if len(src_tokens) > SRC_MAX_LENGTH
-    if len(tgt_tokens) > TGT_MAX_LENGTH:
-        src_tokens = src_tokens[:SRC_MAX_LENGTH]
-        tgt_tokens = tgt_tokens[:TGT_MAX_LENGTH]
-    else:
-        src_tokens = src_tokens[block_size - len(tgt_tokens) - 3]
+if len(src_sequence) > SRC_MAX_LENGTH:
+    if len(tgt_sequence) > TGT_MAX_LENGTH:
+        src_sequence = src_sequence[:SRC_MAX_LENGTH]
+        tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
+    else:
+        src_sequence = src_sequence[:block_size - len(tgt_sequence) - 3]
else:
-    if len(tgt_tokens) > TGT_MAX_LENGTH:
-        tgt_tokens = tgt_tokens[block_size - len(src_tokens) - 3]
+    if len(tgt_sequence) > TGT_MAX_LENGTH:
+        tgt_sequence = tgt_sequence[:block_size - len(src_sequence) - 3]
-# I add the special tokens manually, but this should be done by the tokenizer. That's the next step.
-return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]
+return src_sequence, tgt_sequence