chenpangpang / transformers

Commit 41279327, authored Oct 14, 2019 by Rémi Louf
delegate the padding with special tokens to the tokenizer
parent 447fffb2
Showing 1 changed file with 22 additions and 31 deletions.

examples/run_seq2seq_finetuning.py (+22, -31)
@@ -53,20 +53,14 @@ def set_seed(args):
 class TextDataset(Dataset):
-    """ Abstracts a dataset used to train seq2seq models.
-
-    A seq2seq dataset consists of two files:
-    - The source file that contains the source sequences, one line per sequence;
-    - The target file contains the target sequences, one line per sequence.
-    The matching betwen source and target sequences is made on the basis of line numbers.
+    """ Abstracts the dataset used to train seq2seq models.
 
     CNN/Daily News:
     The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored
     in different files where the summary sentences are indicated by the special `@highlight` token.
-    To process the data, untar both datasets in the same folder, and path the path to this
-    folder as the "train_data_file" argument. The formatting code was inspired by [2].
+    To process the data, untar both datasets in the same folder, and pass the path to this
+    folder as the "data_dir" argument. The formatting code was inspired by [2].
 
     [1] https://cs.nyu.edu/~kcho/
     [2] https://github.com/abisee/cnn-dailymail/
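
The `@highlight` convention mentioned in the docstring means each raw story file interleaves the article text with its summary sentences. A minimal, hypothetical sketch of that split (the helper name and the joining of highlights into one summary string are assumptions, not the script's actual code):

def split_story(raw_text):
    # Everything before the first `@highlight` is the article; each subsequent
    # `@highlight` block is one summary sentence (assumption based on the docstring).
    parts = [part.strip() for part in raw_text.split("@highlight")]
    story, highlights = parts[0], [h for h in parts[1:] if h]
    summary = " . ".join(highlights)
    return story, summary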
@@ -82,9 +76,8 @@ class TextDataset(Dataset):
                 self.examples = pickle.load(source)
                 return
 
-        logger.info("Creating features from dataset at %s", directory)
-        # we need to iterate over both the cnn and the dailymail dataset
+        logger.info("Creating features from dataset at %s", data_dir)
         datasets = ['cnn', 'dailymail']
         for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
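
For reference, a sketch of how the loop above could enumerate the raw story files once `path_to_stories` is built; the directory layout comes from the `os.path.join(data_dir, dataset, "stories")` call shown, while the `.story` extension and the use of `os.listdir` are assumptions about the untarred archives:

import os

def iter_story_paths(data_dir):
    # Walk data_dir/cnn/stories/ and data_dir/dailymail/stories/ (layout assumed
    # from the diff above) and yield one path per raw story file.
    for dataset in ["cnn", "dailymail"]:
        path_to_stories = os.path.join(data_dir, dataset, "stories")
        for story_file in sorted(os.listdir(path_to_stories)):
            if story_file.endswith(".story"):
                yield os.path.join(path_to_stories, story_file)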
@@ -102,9 +95,10 @@ class TextDataset(Dataset):
                 except IndexError:
                     continue
 
-                src_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
-                tgt_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-                example = _truncate_and_concatenate(src_sequence, tgt_sequence, blocksize)
+                story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
+                summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
+                story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
+                example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
                 self.examples.append(example)
 
         logger.info("Saving features into cache file %s", cached_features_file)
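
Per the commit message, adding the special tokens is now delegated to the tokenizer, so `add_special_token_sequence_pair` is expected to wrap the truncated pair the way the code removed further down did by hand. A hypothetical sketch of that wrapping, operating on token ids (`cls_id`, `eos_id`, and the helper name are illustrative, not the tokenizer's real API):

def wrap_pair(story_seq, summary_seq, cls_id, eos_id):
    # [CLS] story [EOS] summary [EOS]: three special tokens in total, which
    # matches the "+ 3" budget used in _fit_to_block_size below.
    return [cls_id] + story_seq + [eos_id] + summary_seq + [eos_id]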
@@ -158,15 +152,13 @@ def _add_missing_period(line):
     return line + " ."
 
 
-def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
+def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
     """ Concatenate the sequences and adapt their lengths to the block size.
-    Following [1] we perform the following transformations:
-    - Add an [CLS] token at the beginning of the source sequence;
-    - Add an [EOS] token at the end of the source and target sequences;
-    - Concatenate the source and target + tokens sequence. If the concatenated sequence is
-    longer than 512 we follow the 75%/25% rule in [1]: limit the source sequence's length to 384
-    and the target sequence's length to 128.
+    Following [1] we truncate the source and target + tokens sequences so they fit
+    in the block size. If the concatenated sequence is longer than 512 we follow
+    the 75%/25% rule in [1]: limit the source sequence's length to 384 and the
+    target sequence's length to 128.
 
     [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
     Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
@@ -176,22 +168,21 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
     # we dump the examples that are too small to fit in the block size for the
     # sake of simplicity. You can modify this by adding model-specific padding.
-    if len(src_tokens) + len(src_tokens) + 3 < block_size:
+    if len(src_sequence) + len(src_sequence) + 3 < block_size:
         return None
 
     # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
-    if len(src_tokens) > SRC_MAX_LENGTH:
-        if len(tgt_tokens) > TGT_MAX_LENGTH:
-            src_tokens = src_tokens[:SRC_MAX_LENGTH]
-            tgt_tokens = tgt_tokens[:TGT_MAX_LENGTH]
+    if len(src_sequence) > SRC_MAX_LENGTH:
+        if len(tgt_sequence) > TGT_MAX_LENGTH:
+            src_sequence = src_sequence[:SRC_MAX_LENGTH]
+            tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
         else:
-            src_tokens = src_tokens[block_size - len(tgt_tokens) - 3]
+            src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
     else:
         if len(tgt_tokens) > TGT_MAX_LENGTH:
-            tgt_tokens = tgt_tokens[block_size - len(src_tokens) - 3]
+            tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
 
-    # I add the special tokens manually, but this should be done by the tokenizer. That's the next step.
-    return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]
+    return src_sequence, tgt_sequence
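
Read on its own, the 75%/25% rule the docstring describes amounts to the following self-contained sketch, with SRC_MAX_LENGTH and TGT_MAX_LENGTH assumed to be 384 and 128 (they are used but not defined in the hunk shown), the length check applied to the combined source and target, and the truncations written as slices:

SRC_MAX_LENGTH = 384  # 75% of a 512-token block
TGT_MAX_LENGTH = 128  # 25% of a 512-token block

def fit_to_block_size_sketch(src_sequence, tgt_sequence, block_size=512):
    # Drop examples too short to fill a block (the alternative being
    # model-specific padding, as the comment in the diff notes).
    if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
        return None

    if len(src_sequence) > SRC_MAX_LENGTH:
        if len(tgt_sequence) > TGT_MAX_LENGTH:
            # Both sides are long: apply the 75%/25% split.
            src_sequence = src_sequence[:SRC_MAX_LENGTH]
            tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
        else:
            # Only the source is long: keep what the target leaves over,
            # reserving 3 positions for the special tokens.
            src_sequence = src_sequence[: block_size - len(tgt_sequence) - 3]
    elif len(tgt_sequence) > TGT_MAX_LENGTH:
        # Only the target is long.
        tgt_sequence = tgt_sequence[: block_size - len(src_sequence) - 3]

    return src_sequence, tgt_sequence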
@@ -250,4 +241,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file