chenpangpang / transformers · Commits

Commit 8cd56e30, authored Oct 17, 2019 by thomwolf
Parent: 578d23e0

fix data processing in script
Showing 1 changed file with 15 additions and 34 deletions.

examples/run_seq2seq_finetuning.py (+15 / -34)
```diff
@@ -58,12 +58,12 @@ class TextDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """
-    def __init__(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
+    def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512):
         assert os.path.isdir(data_dir)

         # Load features that have already been computed if present
         cached_features_file = os.path.join(
-            data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
+            data_dir, "cached_lm_{}_{}".format(block_size, prefix)
         )

         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
```
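Besides collapsing the two tokenizers into one, this hunk fixes the cache filename: the old code formatted `data_dir` into the `cached_lm_{}_{}` template, so the directory path itself landed in the filename and every split written under the same directory mapped to the same cache file. The new `prefix` argument keeps per-split caches apart. A minimal sketch of the resulting paths, using hypothetical values:

```python
import os

# Hypothetical values, for illustration only.
data_dir, block_size = "./data", 512

# Old behaviour: the directory path is baked into the cache name,
# and every split under the same directory collides on one file.
old = os.path.join(data_dir, "cached_lm_{}_{}".format(block_size, data_dir))
print(old)  # ./data/cached_lm_512_./data

# New behaviour: one cache file per split prefix.
for prefix in ("train", "dev", "test"):  # hypothetical split names
    print(os.path.join(data_dir, "cached_lm_{}_{}".format(block_size, prefix)))
    # ./data/cached_lm_512_train, ./data/cached_lm_512_dev, ...
```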
```diff
@@ -72,7 +72,7 @@ class TextDataset(Dataset):
             return

         logger.info("Creating features from dataset at %s", data_dir)
         self.examples = []
         datasets = ["cnn", "dailymail"]
         for dataset in datasets:
             path_to_stories = os.path.join(data_dir, dataset, "stories")
```
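For context, this loop resolves each story directory as `data_dir/<dataset>/stories`, which assumes the layout produced by the abisee/cnn-dailymail preprocessing referenced in the class docstring, roughly:

```
data_dir/
├── cnn/
│   └── stories/        # *.story files
└── dailymail/
    └── stories/        # *.story files
```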
```diff
@@ -91,21 +91,17 @@ class TextDataset(Dataset):
                 except IndexError:  # skip ill-formed stories
                     continue

-                story = tokenizer_src.convert_tokens_to_ids(
-                    tokenizer_src.tokenize(story)
-                )
+                story = tokenizer.encode(story)
                 story_seq = _fit_to_block_size(story, block_size)

-                summary = tokenizer_tgt.convert_tokens_to_ids(
-                    tokenizer_tgt.tokenize(summary)
-                )
+                summary = tokenizer.encode(summary)
                 summary_seq = _fit_to_block_size(summary, block_size)

                 self.examples.append((story_seq, summary_seq))

         logger.info("Saving features into cache file %s", cached_features_file)
         with open(cached_features_file, "wb") as sink:
-            pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
+            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

     def __len__(self):
         return len(self.examples)
```
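Two fixes in this hunk: with a single shared tokenizer, `tokenizer.encode(text)` replaces the two-step `convert_tokens_to_ids(tokenize(text))`, and the misspelled `protocole=` keyword, which would have raised a `TypeError` the first time the cache was written, becomes `protocol=`. A quick sketch of both, assuming a BERT tokenizer:

```python
import pickle

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
text = "Some story text."

# encode() is the one-step equivalent of tokenize + convert_tokens_to_ids
# (when no special tokens are added).
two_step = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
one_step = tokenizer.encode(text, add_special_tokens=False)
assert one_step == two_step

with open("cached_example.bin", "wb") as sink:
    # The old 'protocole=' keyword would raise a TypeError here.
    pickle.dump([one_step], sink, protocol=pickle.HIGHEST_PROTOCOL)
```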
```diff
@@ -169,11 +165,11 @@ def _fit_to_block_size(sequence, block_size):
     if len(sequence) > block_size:
         return sequence[:block_size]
     else:
-        return sequence.extend([-1] * [block_size - len(sequence)])
+        return sequence.extend([-1] * (block_size - len(sequence)))


-def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
-    dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
+def load_and_cache_examples(args, tokenizer):
+    dataset = TextDataset(tokenizer, data_dir=args.data_dir)
     return dataset
```
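The padding fix above addresses a genuine crash: the old expression multiplied the list `[-1]` by another list, and Python only defines `list * int`. Note that `list.extend` returns `None`, so even the corrected line returns `None` rather than the padded sequence; a hypothetical helper (not part of the script) that returns the padded list would look like this:

```python
def pad_to_block_size(sequence, block_size, pad_token_id=-1):
    """Truncate or right-pad `sequence` to exactly `block_size` ids.

    Hypothetical variant of the script's _fit_to_block_size that returns
    the padded list instead of the None produced by list.extend.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    return sequence + [pad_token_id] * (block_size - len(sequence))


# [-1] * [3] would raise "TypeError: can't multiply sequence by non-int";
# [-1] * (3) is the intended repetition.
assert pad_to_block_size([3, 1, 4, 1, 5], 8) == [3, 1, 4, 1, 5, -1, -1, -1]
```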
```diff
@@ -293,29 +289,17 @@ def main():
         "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
     )
     parser.add_argument(
-        "--decoder_name_or_path",
+        "--model_name_or_path",
         default="bert-base-cased",
         type=str,
-        help="The model checkpoint to initialize the decoder's weights with.",
+        help="The model checkpoint to initialize the encoder and decoder's weights with.",
     )
     parser.add_argument(
-        "--decoder_type",
+        "--model_type",
         default="bert",
         type=str,
         help="The decoder architecture to be fine-tuned.",
     )
-    parser.add_argument(
-        "--encoder_name_or_path",
-        default="bert-base-cased",
-        type=str,
-        help="The model checkpoint to initialize the encoder's weights with.",
-    )
-    parser.add_argument(
-        "--encoder_type",
-        default="bert",
-        type=str,
-        help="The encoder architecture to be fine-tuned.",
-    )
     parser.add_argument(
         "--learning_rate",
         default=5e-5,
```
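After this change the script takes a single model flag pair instead of separate encoder/decoder pairs. A hedged invocation sketch, using only flags visible in this diff plus a `--data_dir` flag assumed to back `args.data_dir`:

```sh
python examples/run_seq2seq_finetuning.py \
    --model_type bert \
    --model_name_or_path bert-base-cased \
    --data_dir ./cnn_dailymail_data
```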
```diff
@@ -346,7 +330,7 @@ def main():
     )
     args = parser.parse_args()

-    if args.encoder_type != "bert" or args.decoder_type != "bert":
+    if args.model_type != "bert":
         raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
```
```diff
@@ -358,11 +342,8 @@ def main():
     set_seed(args)

     # Load pretrained model and tokenizer
-    encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
-    decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
-    model = Model2Model.from_pretrained(
-        args.encoder_name_or_path, args.decoder_name_or_path
-    )
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+    model = Model2Model.from_pretrained(args.model_name_or_path)
     # model.to(device)

     logger.info("Training/evaluation parameters %s", args)
```
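The net effect of the commit shows up here: one checkpoint name drives both the shared tokenizer and the `Model2Model` encoder/decoder pair. A minimal sketch against the transformers API as it existed at the time of this commit (later releases replaced `Model2Model` with the encoder-decoder classes):

```python
# transformers 2.x-era API, matching this commit.
from transformers import AutoTokenizer, Model2Model

model_name_or_path = "bert-base-cased"

# One name now yields both the shared tokenizer and the model weights.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = Model2Model.from_pretrained(model_name_or_path)

# Both sides of a (story, summary) pair use the same tokenizer,
# matching the new TextDataset code above.
story_ids = tokenizer.encode("A long news article ...")
summary_ids = tokenizer.encode("A short summary ...")
```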