chenpangpang / transformers / Commits

Commit 47a06d88, authored Oct 17, 2019 by Rémi Louf
use two different tokenizers for story and summary
parent bfb9b540

Showing 1 changed file with 36 additions and 18 deletions:
examples/run_seq2seq_finetuning.py (+36, -18)
@@ -26,7 +26,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
-from transformers import BertTokenizer
+from transformers import AutoTokenizer, Model2Model
 
 logger = logging.getLogger(__name__)
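The import swap above is what enables the rest of the commit: AutoTokenizer picks the tokenizer class from the checkpoint name, so the two sides of the model no longer share a hard-coded BertTokenizer, and Model2Model is the library's encoder-decoder wrapper that the updated main() builds. A minimal sketch of the tokenizer side, loading one independent instance per side with this commit's default checkpoint:

from transformers import AutoTokenizer

# One tokenizer per side; AutoTokenizer dispatches to BertTokenizer here
# because of the checkpoint name.
tokenizer_src = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_tgt = AutoTokenizer.from_pretrained("bert-base-cased")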
@@ -57,7 +57,7 @@ class TextDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """
 
-    def __init__(self, tokenizer, data_dir="", block_size=512):
+    def __init__(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
         assert os.path.isdir(data_dir)
 
         # Load features that have already been computed if present
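With the widened constructor, dataset construction now takes one tokenizer per side. A minimal sketch using the two tokenizers from the sketch above; "./cnn_stories" is a hypothetical path, not part of the commit, and TextDataset is the class defined in this file:

# data_dir must be an existing directory, since the constructor
# asserts os.path.isdir(data_dir).
dataset = TextDataset(tokenizer_src, tokenizer_tgt, data_dir="./cnn_stories", block_size=512)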
@@ -90,15 +90,13 @@ class TextDataset(Dataset):
             except IndexError:  # skip ill-formed stories
                 continue
 
-            summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-            summary_seq = _fit_to_block_size(summary, block_size)
-
-            story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
+            story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
             story_seq = _fit_to_block_size(story, block_size)
 
-            self.examples.append(
-                tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
-            )
+            summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
+            summary_seq = _fit_to_block_size(summary, block_size)
+
+            self.examples.append((story_seq, summary_seq))
 
         logger.info("Saving features into cache file %s", cached_features_file)
         with open(cached_features_file, "wb") as sink:
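The encoding step changes in two ways: each side now runs through its own tokenizer, and the example is stored as a plain (story_seq, summary_seq) tuple instead of one sequence joined by add_special_token_sequence_pair. A self-contained sketch of the per-side encoding; the encode call with an explicit add_special_tokens=False should be equivalent shorthand for the two-step tokenize/convert used in the hunk:

from transformers import AutoTokenizer

tokenizer_src = AutoTokenizer.from_pretrained("bert-base-cased")
story = "London (CNN) -- A short example story."

# The two-step encoding used in the hunk:
story_ids = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))

# Shorthand that should produce the same ids; the explicit flag keeps
# special tokens ([CLS]/[SEP]) out, matching the two-step version.
assert story_ids == tokenizer_src.encode(story, add_special_tokens=False)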
@@ -169,8 +167,8 @@ def _fit_to_block_size(sequence, block_size):
         return sequence.extend([-1] * [block_size - len(sequence)])
 
 
-def load_and_cache_examples(args, tokenizer):
-    dataset = TextDataset(tokenizer, file_path=args.data_dir)
+def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
+    dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
     return dataset
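Note that the unchanged context line in this hunk, return sequence.extend([-1] * [block_size - len(sequence)]), carries two pre-existing bugs that this commit does not touch: list.extend returns None, and [-1] * [block_size - len(sequence)] multiplies a list by a list, which raises a TypeError. A corrected sketch of the padding helper; the truncation branch is an assumption, since the hunk shows only the padding return:

def _fit_to_block_size(sequence, block_size):
    # Truncation branch assumed; it is not shown in the hunk.
    if len(sequence) >= block_size:
        return sequence[:block_size]
    # Fixes versus the shown line: the repetition count is an int, and a
    # new list is returned because list.extend returns None.
    return sequence + [-1] * (block_size - len(sequence))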
@@ -205,14 +203,35 @@ def main():
 
     # Optional parameters
     parser.add_argument(
-        "--model_name_or_path",
+        "--decoder_name_or_path",
         default="bert-base-cased",
         type=str,
-        help="The model checkpoint for weights initialization.",
+        help="The model checkpoint to initialize the decoder's weights with.",
     )
+    parser.add_argument(
+        "--decoder_type",
+        default="bert",
+        type=str,
+        help="The decoder architecture to be fine-tuned.",
+    )
+    parser.add_argument(
+        "--encoder_name_or_path",
+        default="bert-base-cased",
+        type=str,
+        help="The model checkpoint to initialize the encoder's weights with.",
+    )
+    parser.add_argument(
+        "--encoder_type",
+        default="bert",
+        type=str,
+        help="The encoder architecture to be fine-tuned.",
+    )
     parser.add_argument("--seed", default=42, type=int)
     args = parser.parse_args()
 
+    if args.encoder_type != 'bert' or args.decoder_type != 'bert':
+        raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
+
     # Set up training device
     # device = torch.device("cpu")
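The single --model_name_or_path option becomes four per-side options, and a new guard pins both sides to BERT for now. A self-contained sketch of just the flags this hunk adds (the real parser also defines required options, such as the data directory, outside this hunk):

import argparse

parser = argparse.ArgumentParser()
# The four options added in this hunk, with the same defaults.
parser.add_argument("--decoder_name_or_path", default="bert-base-cased", type=str)
parser.add_argument("--decoder_type", default="bert", type=str)
parser.add_argument("--encoder_name_or_path", default="bert-base-cased", type=str)
parser.add_argument("--encoder_type", default="bert", type=str)
args = parser.parse_args([])  # empty argv: fall back to the BERT defaults

# The new guard: anything other than BERT on either side is rejected.
if args.encoder_type != "bert" or args.decoder_type != "bert":
    raise ValueError("Only the BERT architecture is currently supported for seq2seq.")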
@@ -220,16 +239,15 @@ def main():
     set_seed(args)
 
     # Load pretrained model and tokenizer
-    tokenizer_class = BertTokenizer
-    # config = config_class.from_pretrained(args.model_name_or_path)
-    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
-    # model = model_class.from_pretrained(args.model_name_or_path, config=config)
+    encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
+    decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
+    model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
     # model.to(device)
 
     logger.info("Training/evaluation parameters %s", args)
 
     # Training
-    _ = load_and_cache_examples(args, tokenizer)
+    source, target = load_and_cache_examples(args, tokenizer)
    # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    # logger.info("  global_step = %s, average loss = %s", global_step, tr_loss)
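Two loose ends remain as committed: encoder_tokenizer_class and decoder_tokenizer_class hold tokenizer instances rather than classes despite their names, and the training call still passes a tokenizer variable that this same hunk removes, while unpacking the single dataset that load_and_cache_examples returns into source, target. A sketch of the call the new signature appears to expect (an assumption, not part of the commit):

# Presumably intended wiring (hypothetical): pass the two tokenizer
# instances loaded above and bind the single returned TextDataset.
train_dataset = load_and_cache_examples(args, encoder_tokenizer_class, decoder_tokenizer_class)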