chenpangpang / transformers
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a32d85f0d405be53117b96075eef2875d2185892"
Commit c1bc709c authored Oct 17, 2019 by Rémi Louf
Parent: 87d60b6e

correct the truncation and padding of dataset
Showing 1 changed file with 10 additions and 38 deletions.
examples/run_seq2seq_finetuning.py  (+10, -38)  @ c1bc709c
@@ -104,9 +104,11 @@ class TextDataset(Dataset):
             except IndexError:  # skip ill-formed stories
                 continue
 
-            story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
             summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-            story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
+            summary_seq = _fit_to_block_size(summary, block_size)
+
+            story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
+            story_seq = _fit_to_block_size(story, block_size)
 
             self.examples.append(
                 tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
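Note on this hunk: the story and the summary are now each fitted to the block size on their own, rather than being truncated jointly. Below is a minimal, self-contained sketch of that flow; the `CLS`, `SEP` and `PAD` ids and the `fit_to_block_size` and `build_example` names are illustrative stand-ins, not the script's actual `_fit_to_block_size` or `tokenizer.add_special_token_sequence_pair` helpers.

# Illustrative sketch only: stand-in special-token ids and a stub fit
# function, not the helpers defined in examples/run_seq2seq_finetuning.py.
CLS, SEP, PAD = 101, 102, -1  # assumed ids for this sketch

def fit_to_block_size(sequence, block_size):
    # Truncate sequences longer than the block, pad shorter ones with PAD.
    if len(sequence) > block_size:
        return sequence[:block_size]
    return sequence + [PAD] * (block_size - len(sequence))

def build_example(story_ids, summary_ids, block_size=512):
    # Each side of the (story, summary) pair is fitted independently,
    # then the two fixed-length sequences are joined with special tokens.
    story_seq = fit_to_block_size(story_ids, block_size)
    summary_seq = fit_to_block_size(summary_ids, block_size)
    return [CLS] + story_seq + [SEP] + summary_seq + [SEP]

example = build_example(list(range(600)), list(range(40)))
assert len(example) == 2 * 512 + 3  # two blocks plus CLS and two SEPs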
@@ -170,45 +172,15 @@ def _add_missing_period(line):
     return line + "."
 
 
-def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
+def _fit_to_block_size(sequence, block_size):
     """ Adapt the source and target sequences' lengths to the block size.
-    If the concatenated sequence (source + target + 3 special tokens) would be
-    longer than the block size we use the 75% / 25% rule followed in [1]. For a
-    block size of 512 this means limiting the source sequence's length to 384
-    and the target sequence's length to 128.
-
-    Attributes:
-        src_sequence (list): a list of ids that maps to the tokens of the
-            source sequence.
-        tgt_sequence (list): a list of ids that maps to the tokens of the
-            target sequence.
-        block_size (int): the model's block size.
-
-    [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
-        Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
+    If the sequence is shorter than the block size we pad it with -1 ids
+    which correspond to padding tokens.
     """
-    SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
-    TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1  # EOS token
-
-    # We dump the examples that are too small to fit in the block size for the
-    # sake of simplicity. You can modify this by adding model-specific padding.
-    if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
-        return None
-
-    if len(src_sequence) > SRC_MAX_LENGTH:
-        if len(tgt_sequence) > TGT_MAX_LENGTH:
-            src_sequence = src_sequence[:SRC_MAX_LENGTH]
-            tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
-        else:
-            remain_size = block_size - len(tgt_sequence) - 3
-            src_sequence = src_sequence[:remain_size]
+    if len(sequence) > block_size:
+        return sequence[:block_size]
     else:
-        if len(tgt_sequence) > TGT_MAX_LENGTH:
-            remain_size = block_size - len(src_sequence) - 3
-            tgt_sequence = tgt_sequence[:remain_size]
-
-    return src_sequence, tgt_sequence
+        return sequence.extend([-1] * [block_size - len(sequence)])
 
 
 def load_and_cache_examples(args, tokenizer):
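Note on the new padding branch: as the diff reads here, it returns the result of `list.extend()`, which in Python is always `None`, and `[-1] * [block_size - len(sequence)]` multiplies a list by a list, which raises a `TypeError`. The sketch below is an editor's illustration of what the new docstring describes (truncate to the block size, otherwise pad with -1 ids and return the sequence); it is not part of the commit.

def _fit_to_block_size(sequence, block_size):
    """Sketch of the documented behaviour: truncate to block_size, or pad with -1 ids."""
    if len(sequence) > block_size:
        return sequence[:block_size]
    # list.extend() mutates in place and returns None, so pad first and
    # return the sequence itself; the multiplier must be an int, not a list.
    sequence.extend([-1] * (block_size - len(sequence)))
    return sequence

# Example: with block_size=8, a 5-id sequence is padded with three -1 ids.
assert _fit_to_block_size([5, 6, 7, 8, 9], 8) == [5, 6, 7, 8, 9, -1, -1, -1]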