chenpangpang / transformers · Commits

Commit 47a06d88
Authored Oct 17, 2019 by Rémi Louf
Parent: bfb9b540

    use two different tokenizers for story and summary

Showing 1 changed file with 36 additions and 18 deletions.

examples/run_seq2seq_finetuning.py (+36, -18)
@@ -26,7 +26,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

-from transformers import BertTokenizer
+from transformers import AutoTokenizer, Model2Model


 logger = logging.getLogger(__name__)
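For orientation, the import swap means the script can resolve a tokenizer from any checkpoint name instead of being hard-wired to BertTokenizer, so the source and target sides can use different vocabularies. A minimal sketch (the checkpoint names are illustrative, not from the commit):

    from transformers import AutoTokenizer

    # AutoTokenizer picks the tokenizer class matching each checkpoint,
    # allowing distinct tokenizers for the story (source) and summary (target).
    tokenizer_src = AutoTokenizer.from_pretrained("bert-base-cased")
    tokenizer_tgt = AutoTokenizer.from_pretrained("bert-base-cased")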
@@ -57,7 +57,7 @@ class TextDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """

-    def __init__(self, tokenizer, data_dir="", block_size=512):
+    def __init__(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
         assert os.path.isdir(data_dir)

         # Load features that have already been computed if present
@@ -90,15 +90,13 @@ class TextDataset(Dataset):
             except IndexError:  # skip ill-formed stories
                 continue

-            summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-            summary_seq = _fit_to_block_size(summary, block_size)
-
-            story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
+            story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
             story_seq = _fit_to_block_size(story, block_size)

-            self.examples.append(
-                tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
-            )
+            summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
+            summary_seq = _fit_to_block_size(summary, block_size)
+
+            self.examples.append((story_seq, summary_seq))

         logger.info("Saving features into cache file %s", cached_features_file)
         with open(cached_features_file, "wb") as sink:
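Taken together, the per-example preprocessing now reduces to the following sketch (the strings and checkpoint names are placeholders; only the two-tokenizer split is from the commit):

    from transformers import AutoTokenizer

    tokenizer_src = AutoTokenizer.from_pretrained("bert-base-cased")
    tokenizer_tgt = AutoTokenizer.from_pretrained("bert-base-cased")

    story = "Full CNN/Daily Mail story text."  # placeholder input
    summary = "Reference summary."             # placeholder target

    # Encode each side with its own tokenizer, as the new __init__ does.
    story_ids = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
    summary_ids = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
    example = (story_ids, summary_ids)  # mirrors self.examples.append((story_seq, summary_seq))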
@@ -169,8 +167,8 @@ def _fit_to_block_size(sequence, block_size):
         return sequence.extend([-1] * [block_size - len(sequence)])


-def load_and_cache_examples(args, tokenizer):
-    dataset = TextDataset(tokenizer, file_path=args.data_dir)
+def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
+    dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
     return dataset
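A side note on the unchanged context line above: "sequence.extend([-1] * [block_size - len(sequence)])" multiplies a list by a list (a TypeError at runtime), and list.extend returns None, so the function would never return the padded sequence. A working padding step might look like this sketch (the truncation branch is an assumption based on the function's name, not shown in the diff):

    def _fit_to_block_size(sequence, block_size):
        # Assumed behavior: truncate long sequences, pad short ones with -1.
        if len(sequence) > block_size:
            return sequence[:block_size]
        return sequence + [-1] * (block_size - len(sequence))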
@@ -205,14 +203,35 @@ def main():
     # Optional parameters
     parser.add_argument(
-        "--model_name_or_path",
+        "--decoder_name_or_path",
         default="bert-base-cased",
         type=str,
-        help="The model checkpoint for weights initialization.",
+        help="The model checkpoint to initialize the decoder's weights with.",
+    )
+    parser.add_argument(
+        "--decoder_type",
+        default="bert",
+        type=str,
+        help="The decoder architecture to be fine-tuned.",
+    )
+    parser.add_argument(
+        "--encoder_name_or_path",
+        default="bert-base-cased",
+        type=str,
+        help="The model checkpoint to initialize the encoder's weights with.",
+    )
+    parser.add_argument(
+        "--encoder_type",
+        default="bert",
+        type=str,
+        help="The encoder architecture to be fine-tuned.",
     )
     parser.add_argument("--seed", default=42, type=int)
     args = parser.parse_args()

+    if args.encoder_type != 'bert' or args.decoder_type != 'bert':
+        raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
+
     # Set up training device
     # device = torch.device("cpu")
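With the new flags in place, an invocation could look like the following (the data path is a placeholder, and --data_dir is assumed to be one of the script's required arguments, since args.data_dir is used above):

    python examples/run_seq2seq_finetuning.py \
        --data_dir ./cnn_stories \
        --encoder_name_or_path bert-base-cased \
        --encoder_type bert \
        --decoder_name_or_path bert-base-cased \
        --decoder_type bert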
@@ -220,16 +239,15 @@ def main():
     set_seed(args)

     # Load pretrained model and tokenizer
-    tokenizer_class = BertTokenizer
-    # config = config_class.from_pretrained(args.model_name_or_path)
-    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
-    # model = model_class.from_pretrained(args.model_name_or_path, config=config)
+    encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
+    decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
+    model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
     # model.to(device)

     logger.info("Training/evaluation parameters %s", args)

     # Training
-    _ = load_and_cache_examples(args, tokenizer)
+    source, target = load_and_cache_examples(args, tokenizer)
     # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
     # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
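Condensed, the model/tokenizer setup introduced by this commit amounts to the following sketch (checkpoint names are illustrative; the two-argument Model2Model.from_pretrained call mirrors the one in the diff):

    from transformers import AutoTokenizer, Model2Model

    # One tokenizer per side, plus an encoder-decoder model built from
    # two (possibly different) checkpoints, as in the new main().
    encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    decoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = Model2Model.from_pretrained("bert-base-cased", "bert-base-cased")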