Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8cd56e30
Commit
8cd56e30
authored
Oct 17, 2019
by
thomwolf
Browse files
fix data processing in script
parent
578d23e0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
34 deletions
+15
-34
examples/run_seq2seq_finetuning.py
examples/run_seq2seq_finetuning.py
+15
-34
No files found.
examples/run_seq2seq_finetuning.py
View file @
8cd56e30
...
...
@@ -58,12 +58,12 @@ class TextDataset(Dataset):
[2] https://github.com/abisee/cnn-dailymail/
"""
def
__init_
(
self
,
tokenizer
_src
,
tokenizer_tgt
,
data_dir
=
""
,
block_size
=
512
):
def
__init_
_
(
self
,
tokenizer
,
prefix
=
'train'
,
data_dir
=
""
,
block_size
=
512
):
assert
os
.
path
.
isdir
(
data_dir
)
# Load features that have already been computed if present
cached_features_file
=
os
.
path
.
join
(
data_dir
,
"cached_lm_{}_{}"
.
format
(
block_size
,
data_dir
)
data_dir
,
"cached_lm_{}_{}"
.
format
(
block_size
,
prefix
)
)
if
os
.
path
.
exists
(
cached_features_file
):
logger
.
info
(
"Loading features from cached file %s"
,
cached_features_file
)
...
...
@@ -72,7 +72,7 @@ class TextDataset(Dataset):
return
logger
.
info
(
"Creating features from dataset at %s"
,
data_dir
)
self
.
examples
=
[]
datasets
=
[
"cnn"
,
"dailymail"
]
for
dataset
in
datasets
:
path_to_stories
=
os
.
path
.
join
(
data_dir
,
dataset
,
"stories"
)
...
...
@@ -91,21 +91,17 @@ class TextDataset(Dataset):
except
IndexError
:
# skip ill-formed stories
continue
story
=
tokenizer_src
.
convert_tokens_to_ids
(
tokenizer_src
.
tokenize
(
story
)
)
story
=
tokenizer
.
encode
(
story
)
story_seq
=
_fit_to_block_size
(
story
,
block_size
)
summary
=
tokenizer_tgt
.
convert_tokens_to_ids
(
tokenizer_tgt
.
tokenize
(
summary
)
)
summary
=
tokenizer
.
encode
(
summary
)
summary_seq
=
_fit_to_block_size
(
summary
,
block_size
)
self
.
examples
.
append
((
story_seq
,
summary_seq
))
logger
.
info
(
"Saving features into cache file %s"
,
cached_features_file
)
with
open
(
cached_features_file
,
"wb"
)
as
sink
:
pickle
.
dump
(
self
.
examples
,
sink
,
protocol
e
=
pickle
.
HIGHEST_PROTOCOL
)
pickle
.
dump
(
self
.
examples
,
sink
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
def
__len__
(
self
):
return
len
(
self
.
examples
)
...
...
@@ -169,11 +165,11 @@ def _fit_to_block_size(sequence, block_size):
if
len
(
sequence
)
>
block_size
:
return
sequence
[:
block_size
]
else
:
return
sequence
.
extend
([
-
1
]
*
[
block_size
-
len
(
sequence
)
]
)
return
sequence
.
extend
([
-
1
]
*
(
block_size
-
len
(
sequence
)
)
)
def
load_and_cache_examples
(
args
,
tokenizer
_src
,
tokenizer_tgt
):
dataset
=
TextDataset
(
tokenizer
_src
,
tokenizer_tgt
,
file_path
=
args
.
data_dir
)
def
load_and_cache_examples
(
args
,
tokenizer
):
dataset
=
TextDataset
(
tokenizer
,
data_dir
=
args
.
data_dir
)
return
dataset
...
...
@@ -293,29 +289,17 @@ def main():
"--adam_epsilon"
,
default
=
1e-8
,
type
=
float
,
help
=
"Epsilon for Adam optimizer."
)
parser
.
add_argument
(
"--
dec
ode
r
_name_or_path"
,
"--
m
ode
l
_name_or_path"
,
default
=
"bert-base-cased"
,
type
=
str
,
help
=
"The model checkpoint to initialize the decoder's weights with."
,
help
=
"The model checkpoint to initialize the
encoder and
decoder's weights with."
,
)
parser
.
add_argument
(
"--
dec
ode
r
_type"
,
"--
m
ode
l
_type"
,
default
=
"bert"
,
type
=
str
,
help
=
"The decoder architecture to be fine-tuned."
,
)
parser
.
add_argument
(
"--encoder_name_or_path"
,
default
=
"bert-base-cased"
,
type
=
str
,
help
=
"The model checkpoint to initialize the encoder's weights with."
,
)
parser
.
add_argument
(
"--encoder_type"
,
default
=
"bert"
,
type
=
str
,
help
=
"The encoder architecture to be fine-tuned."
,
)
parser
.
add_argument
(
"--learning_rate"
,
default
=
5e-5
,
...
...
@@ -346,7 +330,7 @@ def main():
)
args
=
parser
.
parse_args
()
if
args
.
encoder_type
!=
"bert"
or
args
.
decoder
_type
!=
"bert"
:
if
args
.
model
_type
!=
"bert"
:
raise
ValueError
(
"Only the BERT architecture is currently supported for seq2seq."
)
...
...
@@ -358,11 +342,8 @@ def main():
set_seed
(
args
)
# Load pretrained model and tokenizer
encoder_tokenizer_class
=
AutoTokenizer
.
from_pretrained
(
args
.
encoder_name_or_path
)
decoder_tokenizer_class
=
AutoTokenizer
.
from_pretrained
(
args
.
decoder_name_or_path
)
model
=
Model2Model
.
from_pretrained
(
args
.
encoder_name_or_path
,
args
.
decoder_name_or_path
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
Model2Model
.
from_pretrained
(
args
.
model_name_or_path
)
# model.to(device)
logger
.
info
(
"Training/evaluation parameters %s"
,
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment