chenpangpang / transformers · Commits

Commit 22e1af68, authored Oct 15, 2019 by Rémi Louf

truncation function is fully tested

Parent: 260ac7d9
Showing 2 changed files with 74 additions and 59 deletions:

  examples/run_seq2seq_finetuning.py       +58  -43
  examples/run_seq2seq_finetuning_test.py  +16  -16
examples/run_seq2seq_finetuning.py
@@ -41,7 +41,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
-from transformers import BertConfig, Bert2Rnd, BertTokenizer
+from transformers import BertTokenizer
 
 logger = logging.getLogger(__name__)
@@ -57,19 +57,23 @@ class TextDataset(Dataset):
     CNN/Daily News:
 
-    The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as
-    sentences that are prefixed by the special `@highlight` line. To process the
-    data, untar both datasets in the same folder, and pass the path to this
+    The CNN/Daily News raw datasets are downloaded from [1]. The stories are
+    stored in different files; the summary appears at the end of the story as
+    sentences that are prefixed by the special `@highlight` line. To process
+    the data, untar both datasets in the same folder, and pass the path to this
     folder as the "data_dir argument. The formatting code was inspired by [2].
 
     [1] https://cs.nyu.edu/~kcho/
     [2] https://github.com/abisee/cnn-dailymail/
     """
 
-    def __init_(self, tokenizer, data_dir='', block_size=512):
+    def __init_(self, tokenizer, data_dir="", block_size=512):
         assert os.path.isdir(data_dir)
 
         # Load features that have already been computed if present
-        cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir))
+        cached_features_file = os.path.join(
+            data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
+        )
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
             with open(cached_features_file, "rb") as source:
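
Besides the reflow, the visible change to the cache path appears to fix a likely NameError: the old call referenced `directory`, which does not exist in the visible code, while the new call uses the `data_dir` argument. A minimal sketch of the resulting path construction, with a made-up folder name:

    import os

    data_dir = "cnn_dm"  # hypothetical value for the data_dir argument
    block_size = 512

    cached_features_file = os.path.join(
        data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
    )
    print(cached_features_file)  # cnn_dm/cached_lm_512_cnn_dm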
@@ -78,7 +82,7 @@ class TextDataset(Dataset):
             logger.info("Creating features from dataset at %s", data_dir)
 
-            datasets = ['cnn', 'dailymail']
+            datasets = ["cnn", "dailymail"]
             for dataset in datasets:
                 path_to_stories = os.path.join(data_dir, dataset, "stories")
                 assert os.path.isdir(path_to_stories)
@@ -99,7 +103,9 @@ class TextDataset(Dataset):
                 story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
                 summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
                 story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
                 example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
                 self.examples.append(example)
 
             logger.info("Saving features into cache file %s", cached_features_file)
@@ -117,7 +123,9 @@ def process_story(raw_story):
     """ Process the text contained in a story file.
     Returns the story and the summary
     """
-    file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
+    file_lines = list(
+        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
+    )
 
     # for some unknown reason some lines miss a period, add it
     file_lines = [_add_missing_period(line) for line in file_lines]
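
For illustration, a small standalone run of the line-filtering step above on a made-up raw story (the input string is hypothetical, not taken from the dataset):

    raw_story = "First sentence\n\n@highlight\n\nThe summary line\n"  # made-up input

    file_lines = list(
        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
    )
    print(file_lines)
    # ['First sentence', '@highlight', 'The summary line']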
@@ -145,7 +153,7 @@ def process_story(raw_story):
 def _add_missing_period(line):
-    END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
+    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
     if line.startswith("@highlight"):
         return line
     if line[-1] in END_TOKENS:
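
Only the head of `_add_missing_period` is visible in this hunk; the fallback that actually appends the period lies outside the shown range. A sketch of the full helper under that assumption (the final `return` line is an assumption, not part of the diff):

    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]

    def _add_missing_period(line):
        if line.startswith("@highlight"):
            return line
        if line[-1] in END_TOKENS:
            return line
        return line + "."  # assumed fallback: terminate the sentence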
@@ -154,34 +162,35 @@ def _add_missing_period(line):
 def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
-    """ Concatenate the sequences and adapt their lengths to the block size.
+    """ Adapt the source and target sequences' lengths to the block size.
 
-    Following [1] we truncate the source and target tokens sequences so they fit
-    in the block size. If the concatenated sequence is longer than 512 we follow
-    the 75%/25% rule in [1]: limit the source sequence's length to 384 and the
-    target sequence's length to 128.
+    If the concatenated sequence (source + target + 3 special tokens) would be
+    longer than the block size we use the 75% / 25% rule followed in [1]. For a
+    block size of 512 this means limiting the source sequence's length to 384
+    and the target sequence's length to 128.
 
     [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
     Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
     """
     SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
-    TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # EOS token
+    TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1  # EOS token
 
-    # we dump the examples that are too small to fit in the block size for the
+    # We dump the examples that are too small to fit in the block size for the
     # sake of simplicity. You can modify this by adding model-specific padding.
-    if len(src_sequence) + len(src_sequence) + 3 < block_size:
+    if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
         return None
 
+    # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
     if len(src_sequence) > SRC_MAX_LENGTH:
         if len(tgt_sequence) > TGT_MAX_LENGTH:
             src_sequence = src_sequence[:SRC_MAX_LENGTH]
             tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
         else:
-            src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
+            remain_size = block_size - len(tgt_sequence) - 3
+            src_sequence = src_sequence[:remain_size]
     else:
         if len(tgt_sequence) > TGT_MAX_LENGTH:
-            tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
+            remain_size = block_size - len(src_sequence) - 3
+            tgt_sequence = tgt_sequence[:remain_size]
 
     return src_sequence, tgt_sequence
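
As a sanity check on the constants above, a tiny standalone sketch that evaluates the new SRC_MAX_LENGTH / TGT_MAX_LENGTH expressions for the default block size of 512 (the helper name is only for illustration):

    def _illustrate_limits(block_size=512):
        # Same expressions as in the hunk above.
        src_max_length = int(0.75 * block_size) - 2             # 382: 75% of 512, minus CLS and EOS
        tgt_max_length = block_size - (src_max_length + 2) - 1  # 127: the remainder, minus EOS
        return src_max_length, tgt_max_length

    print(_illustrate_limits())  # (382, 127); 382 + 127 + 3 special tokens == 512

This is the rough 384/128 split described in the docstring, with room left for the three special tokens.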
@@ -200,44 +209,50 @@ def main():
     parser = argparse.ArgumentParser()
 
     # Required parameters
     parser.add_argument("--data_dir",
                         default=None,
                         type=str,
                         required=True,
-                        help="The input training data file (a text file).")
+                        help="The input training data file (a text file).",
+    )
     parser.add_argument("--output_dir",
                         default=None,
                         type=str,
                         required=True,
-                        help="The output directory where the model predictions and checkpoints will be written.")
+                        help="The output directory where the model predictions and checkpoints will be written.",
+    )
 
     # Optional parameters
     parser.add_argument("--model_name_or_path",
                         default="bert-base-cased",
                         type=str,
-                        help="The model checkpoint for weights initialization.")
+                        help="The model checkpoint for weights initialization.",
+    )
     parser.add_argument("--seed", default=42, type=int)
     args = parser.parse_args()
 
     # Set up training device
-    device = torch.device("cpu")
+    # device = torch.device("cpu")
 
     # Set seed
     set_seed(args)
 
     # Load pretrained model and tokenizer
-    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
+    tokenizer_class = BertTokenizer
-    config = config_class.from_pretrained(args.model_name_or_path)
+    # config = config_class.from_pretrained(args.model_name_or_path)
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
-    model = model_class.from_pretrained(args.model_name_or_path, config=config)
+    # model = model_class.from_pretrained(args.model_name_or_path, config=config)
-    model.to(device)
+    # model.to(device)
 
     logger.info("Training/evaluation parameters %s", args)
 
     # Training
-    train_dataset = load_and_cache_examples(args, tokenizer)
+    _ = load_and_cache_examples(args, tokenizer)
-    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+    # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
-    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+    # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
 
 if __name__ == "__main__":
examples/run_seq2seq_finetuning_test.py
@@ -14,50 +14,50 @@
 # limitations under the License.
 
 import unittest
 
-from .run_seq2seq_finetuning import process_story, _fit_to_block_size
+from run_seq2seq_finetuning import _fit_to_block_size
 
 
 class DataLoaderTest(unittest.TestCase):
-    def __init__(self, block_size=10):
-        self.block_size = block_size
+    def setUp(self):
+        self.block_size = 10
 
-    def source_and_target_too_small(self):
+    def test_source_and_target_too_small(self):
         """ When the sum of the lengths of the source and target sequences is
         smaller than the block size (minus the number of special tokens), skip the example. """
         src_seq = [1, 2, 3, 4]
         tgt_seq = [5, 6]
         self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
 
-    def source_and_target_fit_exactly(self):
+    def test_source_and_target_fit_exactly(self):
         """ When the sum of the lengths of the source and target sequences is
         equal to the block size (minus the number of special tokens), return the
         sequences unchanged. """
         src_seq = [1, 2, 3, 4]
         tgt_seq = [5, 6, 7]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == fitted_src)
-        self.assertListEqual(tgt_seq == fitted_tgt)
+        self.assertListEqual(src_seq, fitted_src)
+        self.assertListEqual(tgt_seq, fitted_tgt)
 
-    def source_too_big_target_ok(self):
+    def test_source_too_big_target_ok(self):
         src_seq = [1, 2, 3, 4, 5, 6]
         tgt_seq = [1, 2]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
-        self.assertListEqual(tgt_seq == fitted_tgt)
+        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
+        self.assertListEqual(fitted_tgt, fitted_tgt)
 
-    def target_too_big_source_ok(self):
+    def test_target_too_big_source_ok(self):
         src_seq = [1, 2, 3, 4]
         tgt_seq = [1, 2, 3, 4]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == src_seq)
-        self.assertListEqual(tgt_seq == [1, 2, 3])
+        self.assertListEqual(fitted_src, src_seq)
+        self.assertListEqual(fitted_tgt, [1, 2, 3])
 
-    def source_and_target_too_big(self):
+    def test_source_and_target_too_big(self):
         src_seq = [1, 2, 3, 4, 5, 6, 7]
         tgt_seq = [1, 2, 3, 4, 5, 6, 7]
         fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
-        self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
-        self.assertListEqual(tgt_seq == [1, 2])
+        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
+        self.assertListEqual(fitted_tgt, [1, 2])
 
 
 if __name__ == "__main__":
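
The rename from `__init__` and bare method names to `setUp` and `test_*` is what makes these cases actually run: unittest only collects methods whose names start with `test_`, and `setUp` is called before each of them. A minimal illustration of that discovery rule (class and method names here are made up):

    import unittest

    class DiscoveryDemo(unittest.TestCase):
        def setUp(self):
            # Runs before every test_* method, like setUp in the diff above.
            self.block_size = 10

        def helper(self):
            # Not collected: the name lacks the "test_" prefix.
            self.fail("never executed by default discovery")

        def test_block_size_is_set(self):
            # Collected and run because of the "test_" prefix.
            self.assertEqual(self.block_size, 10)

    if __name__ == "__main__":
        unittest.main()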