chenpangpang / transformers · Commits · 578d23e0
"build_tools/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "66f4cdf9a95af341da23645f00b308d2caf9a905"
Commit 578d23e0, authored Oct 17, 2019 by Rémi Louf
add training pipeline (formatting temporary)
parent 47a06d88
1 changed file with 130 additions and 9 deletions (+130 -9)
examples/run_seq2seq_finetuning.py (+130 -9)
@@ -23,8 +23,9 @@ import random
 import os
 
 import numpy as np
+from tqdm import tqdm, trange
 
 import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, RandomSampler
 
 from transformers import AutoTokenizer, Model2Model
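Note that the training loop introduced further down also uses DataLoader, AdamW, and WarmupLinearSchedule, which this import hunk does not add (logger and set_seed are presumably defined elsewhere in the file). A minimal sketch of the missing imports, assuming the transformers 2.x optimization API of the time:

# Sketch only: imports the new train() body appears to need but this hunk does not add.
# Assumption: AdamW and WarmupLinearSchedule are importable from transformers circa v2.x.
from torch.utils.data import DataLoader
from transformers import AdamW, WarmupLinearSchedule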
@@ -90,10 +91,14 @@ class TextDataset(Dataset):
             except IndexError:  # skip ill-formed stories
                 continue
 
-            story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
+            story = tokenizer_src.convert_tokens_to_ids(
+                tokenizer_src.tokenize(story)
+            )
             story_seq = _fit_to_block_size(story, block_size)
 
-            summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
+            summary = tokenizer_tgt.convert_tokens_to_ids(
+                tokenizer_tgt.tokenize(summary)
+            )
             summary_seq = _fit_to_block_size(summary, block_size)
 
             self.examples.append((story_seq, summary_seq))
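_fit_to_block_size is defined elsewhere in the file and is not part of this diff; the hunk only shows that each tokenized story and summary is forced to a fixed block_size before being stored as an example pair. A hypothetical helper with that contract might look like the sketch below (illustrative only; the name, pad_token_id, and the truncate-or-pad behaviour are assumptions, not the file's actual implementation):

def _fit_to_block_size_example(sequence, block_size, pad_token_id=0):
    """Truncate or right-pad a list of token ids to exactly block_size items (sketch)."""
    if len(sequence) > block_size:
        return sequence[:block_size]
    return sequence + [pad_token_id] * (block_size - len(sequence))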
@@ -179,7 +184,89 @@ def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
 
 def train(args, train_dataset, model, tokenizer):
     """ Fine-tune the pretrained model on the corpus. """
-    raise NotImplementedError
+    # Prepare the data loading
+    args.train_bach_size = 1
+    train_sampler = RandomSampler(train_dataset)
+    train_dataloader = DataLoader(
+        train_dataset, sampler=train_sampler, batch_size=args.train_bach_size
+    )
+
+    # Prepare the optimizer and schedule (linear warmup and decay)
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if not any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": 0.0,
+        },
+    ]
+    optimizer = AdamW(
+        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
+    )
+    scheduler = WarmupLinearSchedule(
+        optimizer, warmup_steps=args.warmup_steps, t_total=t_total
+    )
+
+    # Train
+    logger.info("***** Running training *****")
+    logger.info(" Num examples = %d", len(train_dataset))
+    logger.info(" Num Epochs = %d", args.num_train_epochs)
+    logger.info(
+        " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
+    )
+    logger.info(
+        " Total train batch size (w. parallel, distributed & accumulation) = %d",
+        args.train_batch_size
+        * args.gradient_accumulation_steps
+        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
+    )
+    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
+    logger.info(" Total optimization steps = %d", t_total)
+
+    global_step = 0
+    tr_loss, logging_loss = 0.0, 0.0
+    model.zero_grad()
+    train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
+    set_seed(args)
+    for _ in train_iterator:
+        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
+        for step, batch in enumerate(epoch_iterator):
+            source = ([s for s, _ in batch]).to(args.device)
+            target = ([t for _, t in batch]).to(args.device)
+            model.train()
+            outputs = model(source, target)
+            loss = outputs[0]
+            loss.backward()
+            tr_loss += loss.item()
+
+            if (step + 1) % args.gradient_accumulation_steps == 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                optimizer.step()
+                scheduler.step()
+                model.zero_grad()
+                global_step += 1
+
+            if args.max_steps > 0 and global_step > args.max_steps:
+                epoch_iterator.close()
+                break
+
+        if args.max_steps > 0 and global_step > args.max_steps:
+            train_iterator.close()
+            break
+
+    return global_step, tr_loss / global_step
 
 
 def main():
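As committed (and as the "formatting temporary" commit message hints), this loop is not yet runnable: args.train_bach_size looks like a typo for train_batch_size, DataLoader is used without being imported in the hunks shown, t_total is never computed before the scheduler uses it, and source/target are plain Python lists, which have no .to() method. One hedged way to turn the (story_seq, summary_seq) pairs into tensors would be a collate function along these lines; it is a sketch under the assumption that every sequence has already been fitted to the same block_size, not the approach this commit takes:

import torch

def collate_story_summary(batch):
    # batch is a list of (story_seq, summary_seq) pairs of equal-length id lists
    source = torch.tensor([s for s, _ in batch], dtype=torch.long)
    target = torch.tensor([t for _, t in batch], dtype=torch.long)
    return source, target

# Hypothetical wiring:
# train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
#                               batch_size=args.train_batch_size,
#                               collate_fn=collate_story_summary)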
@@ -202,6 +289,9 @@ def main():
     )
 
     # Optional parameters
+    parser.add_argument(
+        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
+    )
     parser.add_argument(
         "--decoder_name_or_path",
         default="bert-base-cased",
@@ -226,11 +316,40 @@ def main():
         type=str,
         help="The encoder architecture to be fine-tuned.",
     )
+    parser.add_argument(
+        "--learning_rate",
+        default=5e-5,
+        type=float,
+        help="The initial learning rate for Adam.",
+    )
+    parser.add_argument(
+        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
+    )
+    parser.add_argument(
+        "--max_steps",
+        default=-1,
+        type=int,
+        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        default=1,
+        type=int,
+        help="Total number of training epochs to perform.",
+    )
     parser.add_argument("--seed", default=42, type=int)
+    parser.add_argument(
+        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
+    )
+    parser.add_argument(
+        "--weight_decay", default=0.0, type=float, help="Weight deay if we apply some."
+    )
     args = parser.parse_args()
 
-    if args.encoder_type != 'bert' or args.decoder_type != 'bert':
-        raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
+    if args.encoder_type != "bert" or args.decoder_type != "bert":
+        raise ValueError(
+            "Only the BERT architecture is currently supported for seq2seq."
+        )
 
     # Set up training device
     # device = torch.device("cpu")
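Together these options expose the usual fine-tuning knobs: learning rate, warmup, weight decay, gradient clipping, and a step or epoch budget. They are also exactly what the scheduler's missing t_total would be derived from; the sketch below follows the recipe used in the other example scripts of that era and is not part of this commit (the args attributes assumed here are the ones added above plus gradient_accumulation_steps):

def compute_t_total(args, train_dataloader):
    # Sketch (not in this commit): total optimization steps for the warmup schedule.
    if args.max_steps > 0:
        return args.max_steps
    return len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs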
@@ -241,14 +360,16 @@ def main():
     # Load pretrained model and tokenizer
     encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
     decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
-    model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
+    model = Model2Model.from_pretrained(
+        args.encoder_name_or_path, args.decoder_name_or_path
+    )
     # model.to(device)
 
     logger.info("Training/evaluation parameters %s", args)
 
     # Training
-    source, target = load_and_cache_examples(args, tokenizer)
-    # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+    train_dataset = load_and_cache_examples(args, tokenizer)
+    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
     # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
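main() now actually calls train(), which returns the global step count and the loss averaged over optimizer steps, but the summary logging line stays commented out in this commit. Re-enabling it inside main() is a small change, sketched here with a guard against a zero step count (an addition, not something the commit does):

# Sketch (in the context of main(), after the training call above):
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
if global_step > 0:
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)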