Commit 41279327 authored by Rémi Louf

delegate the padding with special tokens to the tokenizer

parent 447fffb2
@@ -53,20 +53,14 @@ def set_seed(args):
class TextDataset(Dataset):
""" Abstracts a dataset used to train seq2seq models.
A seq2seq dataset consists of two files:
- The source file contains the source sequences, one line per sequence;
- The target file contains the target sequences, one line per sequence.
The matching between source and target sequences is made on the basis of line numbers.
""" Abstracts the dataset used to train seq2seq models.
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. They consist of stories stored
in different files where the summary sentences are indicated by the special `@highlight` token.
To process the data, untar both datasets in the same folder, and path the path to this
folder as the "train_data_file" argument. The formatting code was inspired by [2].
To process the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
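For concreteness, here is a minimal sketch of how a raw story file can be split on the `@highlight` markers described above; the helper name and parsing details are illustrative assumptions, not the code from this commit:

def _split_story_file(path):
    """Split a raw CNN/DailyMail story file into (story, summary)."""
    with open(path, encoding="utf-8") as story_file:
        lines = [line.strip() for line in story_file if line.strip()]
    story_lines, highlights = [], []
    next_is_highlight = False
    for line in lines:
        if line == "@highlight":
            next_is_highlight = True  # the next line is a summary sentence
        elif next_is_highlight:
            highlights.append(line)
            next_is_highlight = False
        else:
            story_lines.append(line)
    return " ".join(story_lines), " . ".join(highlights)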
@@ -82,9 +76,8 @@ class TextDataset(Dataset):
self.examples = pickle.load(source)
return
logger.info("Creating features from dataset at %s", directory)
logger.info("Creating features from dataset at %s", data_dir)
# we need to iterate over both the cnn and the dailymail dataset
datasets = ['cnn', 'dailymail']
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
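The loop above assumes the layout produced by untarring both archives under `data_dir`, each with a `stories/` subfolder. A self-contained sketch of collecting every story file (the helper is hypothetical):

import os

def _list_story_files(data_dir):
    """Collect paths to all story files under data_dir/{cnn,dailymail}/stories."""
    story_files = []
    for dataset in ["cnn", "dailymail"]:
        path_to_stories = os.path.join(data_dir, dataset, "stories")
        for file_name in sorted(os.listdir(path_to_stories)):
            story_files.append(os.path.join(path_to_stories, file_name))
    return story_files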
@@ -102,9 +95,10 @@ class TextDataset(Dataset):
except IndexError:
continue
src_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
tgt_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
example = _truncate_and_concatenate(src_sequence, tgt_sequence, blocksize)
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
self.examples.append(example)
logger.info("Saving features into cache file %s", cached_features_file)
@@ -158,15 +152,13 @@ def _add_missing_period(line):
return line + " ."
def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
""" Concatenate the sequences and adapt their lengths to the block size.
Following [1] we perform the following transformations:
- Add a [CLS] token at the beginning of the source sequence;
- Add an [EOS] token at the end of the source and target sequences;
- Concatenate the source and target + tokens sequence. If the concatenated sequence is
longer than 512 we follow the 75%/25% rule in [1]: limit the source sequence's length to 384
and the target sequence's length to 128.
Following [1], we truncate the source and target sequences so that, together
with the special tokens, they fit in the block size. If the concatenated
sequence would be longer than 512 we follow the 75%/25% rule in [1]: limit
the source sequence's length to 384 and the target sequence's length to 128.
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
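Concretely, the caps are the 75%/25% split of a 512-token block; a short sketch of the arithmetic (the constant names follow the code below, the block size of 512 is taken from the docstring):

BLOCK_SIZE = 512
SRC_MAX_LENGTH = int(0.75 * BLOCK_SIZE)  # 384 tokens for the source
TGT_MAX_LENGTH = int(0.25 * BLOCK_SIZE)  # 128 tokens for the target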
@@ -176,22 +168,21 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
# we dump the examples that are too small to fill the block size, for the
# sake of simplicity. You can modify this by adding model-specific padding.
if len(src_tokens) + len(tgt_tokens) + 3 < block_size:
if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
return None
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
if len(src_tokens) > SRC_MAX_LENGTH:
if len(tgt_tokens) > TGT_MAX_LENGTH:
src_tokens = src_tokens[:SRC_MAX_LENGTH]
tgt_tokens = tgt_tokens[:TGT_MAX_LENGTH]
if len(src_sequence) > SRC_MAX_LENGTH:
if len(tgt_sequence) > TGT_MAX_LENGTH:
src_sequence = src_sequence[:SRC_MAX_LENGTH]
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
else:
src_tokens = src_tokens[:block_size - len(tgt_tokens) - 3]
src_sequence = src_sequence[:block_size - len(tgt_sequence) - 3]
else:
if len(tgt_sequence) > TGT_MAX_LENGTH:
tgt_tokens = tgt_tokens[:block_size - len(src_tokens) - 3]
tgt_sequence = tgt_sequence[:block_size - len(src_sequence) - 3]
# I add the special tokens manually, but this should be done by the tokenizer. That's the next step.
return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]
return src_sequence, tgt_sequence
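A small usage sketch of `_fit_to_block_size` with dummy token ids, choosing lengths so both caps trigger (assumes SRC_MAX_LENGTH = 384 and TGT_MAX_LENGTH = 128 as above):

src = list(range(600))  # source longer than SRC_MAX_LENGTH
tgt = list(range(200))  # target longer than TGT_MAX_LENGTH
src_fit, tgt_fit = _fit_to_block_size(src, tgt, block_size=512)
assert len(src_fit) == 384 and len(tgt_fit) == 128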
@@ -250,4 +241,4 @@ def main():
if __name__ == "__main__":
main()
\ No newline at end of file
main()