chenpangpang / transformers · commit f5c2a122

Upgrade examples to pl=0.8.1 (#5146)

Unverified commit f5c2a122, authored Jun 22, 2020 by Sam Shleifer, committed by GitHub on Jun 22, 2020. Parent: 06b60c8b.

Showing 11 changed files with 53 additions and 150 deletions (+53, -150).
Changed files:

examples/lightning_base.py (+22, -41)
examples/requirements.txt (+1, -1)
examples/summarization/callbacks.py (+1, -2)
examples/summarization/distillation.py (+1, -0)
examples/summarization/finetune.py (+6, -23)
examples/summarization/run_distiller.sh (+0, -1)
examples/summarization/run_eval.py (+1, -0)
examples/summarization/test_summarization_examples.py (+16, -73)
examples/summarization/utils.py (+2, -0)
examples/text-classification/run_pl_glue.py (+2, -8)
src/transformers/tokenization_auto.py (+1, -1)
examples/lightning_base.py

@@ -8,6 +8,7 @@ from typing import Any, Dict
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

 from transformers import (
     AdamW,

@@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule):
         model=None,
         **config_kwargs
     ):
-        "Initialize a model."
+        """Initialize a model, tokenizer and config."""
         super().__init__()
-        self.hparams = hparams
+        self.hparams = hparams  # TODO: move to self.save_hyperparameters()
         self.step_count = 0
         self.tfmr_ckpts = {}
         self.output_dir = Path(self.hparams.output_dir)

@@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule):
             )
         else:
             self.tokenizer: PreTrainedTokenizer = tokenizer
-        self.model_type = MODEL_MODES[mode]
         if model is None:
+            self.model_type = MODEL_MODES[mode]
             self.model = self.model_type.from_pretrained(
                 self.hparams.model_name_or_path,
                 from_tf=bool(".ckpt" in self.hparams.model_name_or_path),

@@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule):
                 cache_dir=cache_dir,
             )
         else:
+            self.model_type = None
             self.model = model

     def load_hf_checkpoint(self, *args, **kwargs):
         self.model = self.model_type.from_pretrained(*args, **kwargs)

-    def is_logger(self):
-        return self.trainer.proc_rank <= 0
-
     def configure_optimizers(self):
         "Prepare optimizer and schedule (linear warmup and decay)"
         model = self.model
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [

@@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule):
         self.opt = optimizer
         return [optimizer]

-    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
-        if self.trainer.use_tpu:
-            xm.optimizer_step(optimizer)
-        else:
-            optimizer.step()
-        optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-    def get_tqdm_dict(self):
-        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
-        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
-        return tqdm_dict
-
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)

-    def test_end(self, outputs):
+    def test_epoch_end(self, outputs):
         return self.validation_end(outputs)

     def train_dataloader(self):

@@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
         parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
+        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
         parser.add_argument(
             "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
         )

@@ -217,28 +200,26 @@ class BaseTransformer(pl.LightningModule):

 class LoggingCallback(pl.Callback):
+    @rank_zero_only
     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
-        logger.info("***** Validation results *****")
-        if pl_module.is_logger():
-            metrics = trainer.callback_metrics
-            # Log results
-            for key in sorted(metrics):
-                if key not in ["log", "progress_bar"]:
-                    logger.info("{} = {}\n".format(key, str(metrics[key])))
+        rank_zero_info("***** Validation results *****")
+        metrics = trainer.callback_metrics
+        # Log results
+        for key in sorted(metrics):
+            if key not in ["log", "progress_bar"]:
+                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

+    @rank_zero_only
     def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Test results *****")
-        if pl_module.is_logger():
-            metrics = trainer.callback_metrics
-            # Log and save results to file
-            output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
-            with open(output_test_results_file, "w") as writer:
-                for key in sorted(metrics):
-                    if key not in ["log", "progress_bar"]:
-                        logger.info("{} = {}\n".format(key, str(metrics[key])))
-                        writer.write("{} = {}\n".format(key, str(metrics[key])))
+        metrics = trainer.callback_metrics
+        # Log and save results to file
+        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
+        with open(output_test_results_file, "w") as writer:
+            for key in sorted(metrics):
+                if key not in ["log", "progress_bar"]:
+                    logger.info("{} = {}\n".format(key, str(metrics[key])))
+                    writer.write("{} = {}\n".format(key, str(metrics[key])))


 def add_generic_args(parser, root_dir) -> None:
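Note on the lightning_base.py changes: they track the pytorch-lightning 0.7.x → 0.8.x API. The test_end hook becomes test_epoch_end, the hand-rolled optimizer_step and get_tqdm_dict overrides are dropped in favor of the framework defaults, and rank-zero gating moves from a trainer.proc_rank check to rank_zero_only / rank_zero_info. A minimal sketch of the renamed epoch-end hooks, assuming pytorch-lightning 0.8.1 (the module and metric names below are illustrative, not from this repo):

import pytorch_lightning as pl
import torch


class TinyModule(pl.LightningModule):
    """Illustrative only: shows the 0.8.x epoch-end hook names."""

    def validation_step(self, batch, batch_idx):
        return {"val_loss": torch.tensor(0.0)}

    # pl 0.8.x aggregates per-epoch outputs in validation_epoch_end
    # (validation_end is deprecated).
    def validation_epoch_end(self, outputs):
        avg = torch.stack([o["val_loss"] for o in outputs]).mean()
        return {"val_loss": avg, "log": {"val_loss": avg}}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    # test_epoch_end replaces the deprecated test_end, mirroring the diff above.
    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs)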
examples/requirements.txt

@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.7.6
+pytorch-lightning==0.8.1
 matplotlib
 git-python==1.0.3
 faiss
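The only change here is the pin: pytorch-lightning 0.7.6 → 0.8.1, which is what forces the hook renames elsewhere in this commit. A small hedged guard for environments that drift from the pin (not part of this repo, just an illustration):

import pytorch_lightning as pl

# The examples directory assumes the exact release pinned in requirements.txt.
assert pl.__version__ == "0.8.1", f"expected pytorch-lightning 0.8.1, got {pl.__version__}"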
examples/summarization/callbacks.py

@@ -19,12 +19,11 @@ logger = logging.getLogger(__name__)

 class Seq2SeqLoggingCallback(pl.Callback):
+    @rank_zero_only
     def _write_logs(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
     ) -> None:
         logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
-        if not pl_module.is_logger():
-            return
         metrics = trainer.callback_metrics
         trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
         # Log results
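Same pattern as LoggingCallback above: rather than entering the method on every process and bailing out via pl_module.is_logger(), the whole method is decorated with rank_zero_only so it only executes on the global-rank-0 process. A minimal sketch, assuming pytorch-lightning 0.8.1 (the callback name and metric filtering below are illustrative):

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only


class RankZeroLoggingCallback(pl.Callback):
    # Decorated methods are no-ops on every process except global rank 0,
    # which replaces manual `if not pl_module.is_logger(): return` guards.
    @rank_zero_only
    def on_validation_end(self, trainer, pl_module):
        for key, value in sorted(trainer.callback_metrics.items()):
            if key not in ("log", "progress_bar"):
                rank_zero_info(f"{key} = {value}")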
examples/summarization/distillation.py

@@ -271,6 +271,7 @@ class SummarizationDistiller(SummarizationModule):

 class T5SummarizationDistiller(SummarizationDistiller):
     def pre_init(self, hparams):
+        raise NotImplementedError("T5 Distillation does not work yet")
         teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
         n_layer = hparams.student_decoder_layers
         assert n_layer == hparams.student_encoder_layers  # TODO(SS): relax this
examples/summarization/finetune.py

@@ -85,7 +85,7 @@ class SummarizationModule(BaseTransformer):
         if self.hparams.freeze_encoder:
             freeze_params(self.model.model.encoder)  # TODO: this will break for t5
         self.hparams.git_sha = get_git_info()["repo_sha"]
-        self.num_workers = 4 if self.hparams.gpus <= 1 else None  # passing num_workers breaks lightning for multigpu
+        self.num_workers = hparams.num_workers

     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""

@@ -126,7 +126,7 @@ class SummarizationModule(BaseTransformer):
     def validation_step(self, batch, batch_idx) -> Dict:
         return self._generative_step(batch)

-    def validation_end(self, outputs, prefix="val") -> Dict:
+    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]

@@ -144,14 +144,12 @@ class SummarizationModule(BaseTransformer):
         self.metrics[prefix].append(metrics)
         pickle_save(self.metrics, self.metrics_save_path)

-    def _generative_step(self, batch):
+    def _generative_step(self, batch: dict) -> dict:
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
-        # TODO(SS): task specific params
         t0 = time.time()
         generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
-        gen_time = time.time() - t0
+        gen_time = (time.time() - t0) / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
         loss_tensors = self._step(batch)

@@ -164,24 +162,8 @@ class SummarizationModule(BaseTransformer):
     def test_step(self, batch, batch_idx):
         return self._generative_step(batch)

-    def test_end(self, outputs):
-        return self.validation_end(outputs, prefix="test")
-
     def test_epoch_end(self, outputs):
-        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
-        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
-        # write predictions and targets for later rouge evaluation.
-        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
-            for output_batch in outputs:
-                p_writer.writelines(s + "\n" for s in output_batch["preds"])
-                t_writer.writelines(s + "\n" for s in output_batch["target"])
-            p_writer.close()
-            t_writer.close()
-        return self.test_end(outputs)
-
-    def validation_epoch_end(self, outputs):
-        self.validation_end(outputs, "val")
+        return self.validation_epoch_end(outputs, prefix="test")

     def get_dataset(self, type_path) -> SummarizationDataset:
         n_obs = self.n_obs[type_path]

@@ -310,6 +292,7 @@ def main(args, model=None) -> SummarizationModule:
         logger=logger,
         # TODO: early stopping callback seems messed up
     )
+    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
     if not args.do_predict:
         return model
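Two things stand out in finetune.py: the hard-coded `num_workers = 4 if ... else None` workaround is replaced by the new --num_workers flag added in lightning_base.py, and gen_time is now normalized per example rather than per batch. A small sketch of how an argparse-driven num_workers reaches a DataLoader (the dataset below is a stand-in, not the repo's SummarizationDataset):

import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset

parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
args = parser.parse_args([])  # keep the default for this sketch

# Stand-in dataset; in the examples this would be SummarizationDataset.
dataset = TensorDataset(torch.zeros(8, 4))
loader = DataLoader(dataset, batch_size=2, num_workers=args.num_workers)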
examples/summarization/run_distiller.sh

@@ -7,6 +7,5 @@ python distillation.py \
   --learning_rate=3e-4 \
   --do_train \
   --do_predict \
-  --fp16 \
   --val_check_interval 0.1 \
   $@
examples/summarization/run_eval.py

@@ -26,6 +26,7 @@ def generate_summaries(
     examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")
+    model_name = str(model_name)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
examples/summarization/test_summarization_examples.py

@@ -24,6 +24,7 @@ logger = logging.getLogger()
 FP16_EVER = False
 CHEAP_ARGS = {
     "logger": "default",
+    "num_workers": 2,
     "alpha_hid": 0,
     "freeze_embeds": True,
     "enc_only": False,

@@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list):
         f.write("\n".join(articles))


-BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
+MSG = "T5 is broken at the moment"
+T5_TINY = "patrickvonplaten/t5-tiny-random"


 def make_test_data_dir():

@@ -92,7 +94,6 @@ def make_test_data_dir():
     return tmp_dir


-@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
 class TestSummarizationDistiller(unittest.TestCase):
     @classmethod
     def setUpClass(cls):

@@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase):
             freeze_encoder=True,
             gpus=2,
             sortish_sampler=False,
         )
         self._bart_distiller_cli(updates)

-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_fp16(self):
-        updates = dict(
-            student_encoder_layers=2,
-            student_decoder_layers=1,
-            alpha_hid=3.0,
-            freeze_encoder=True,
-            gpus=1,
-            fp16=FP16_EVER,
-            fp16_opt_level="O1",
-        )
-        self._bart_distiller_cli(updates)
-
-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_t5_eval_fp16(self):
-        updates = dict(
-            fp16=FP16_EVER,
-            gpus=1,
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            do_train=False,
-            do_predict=True,
-            tokenizer_name=None,
-            no_teacher=True,
-        )
-        self._bart_distiller_cli(updates, check_contents=False)
-
-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_t5_train(self):
+    def test_bdc_t5_train_fp16(self):
         updates = dict(
             fp16=FP16_EVER,
-            gpus=1,
+            fp16_opt_level="O1",
+            gpus=1 if torch.cuda.is_available() else 0,
             model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            model_name_or_path=T5_TINY,
             do_train=True,
             do_predict=True,
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
+            tokenizer_name=T5_TINY,
             no_teacher=True,
+            alpha_hid=2.0,
         )
         self._bart_distiller_cli(updates)

@@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase):
         self._bart_distiller_cli(updates)

-
     def test_bdc_checkpointing(self):
         updates = dict(
             student_encoder_layers=2,
             student_decoder_layers=1,

@@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

-    def test_bdc_t5(self):
-        updates = dict(
-            student_encoder_layers=1,
-            student_decoder_layers=1,
-            alpha_hid=2.0,
-            teacher="patrickvonplaten/t5-tiny-random",
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
-        )
-        self._bart_distiller_cli(updates)
-
-    def test_bdc_t5_eval(self):
-        updates = dict(
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            do_train=False,
-            do_predict=True,
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
-            no_teacher=True,
-        )
-        self._bart_distiller_cli(updates, check_contents=False)
-
     def _bart_distiller_cli(self, updates, check_contents=True):
         default_updates = dict(
+            model_type="bart",
             train_batch_size=1,
             eval_batch_size=2,
             num_train_epochs=2,

@@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertIn(ckpt_name, contents)
         self.assertIn("metrics.pkl", contents)
         self.assertIn("test_generations.txt", contents)
-        self.assertIn("val_generations_1.txt", contents)
-        self.assertIn("val_1_results.txt", contents)
+        self.assertIn("val_generations_00001.txt", contents)
+        self.assertIn("val_results_00001.txt", contents)
         self.assertIn("test_results.txt", contents)
-        # self.assertEqual(len(contents), 15)
         metrics = pickle_load(Path(output_dir) / "metrics.pkl")
-        import pandas as pd
-
-        val_df = pd.DataFrame(metrics["val"])
-        train_df = pd.DataFrame(metrics["train"])
-        test_df = pd.DataFrame(metrics["test"])
-        desired_n_evals = args_d["num_train_epochs"] * 2 + 1
-        self.assertEqual(val_df.shape[0], desired_n_evals)
-        self.assertEqual(test_df.shape[1], val_df.shape[1])
-        self.assertEqual(train_df.shape[0], 0)
+        desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
+        self.assertEqual(len(metrics["val"]), desired_n_evals)
+        self.assertEqual(len(metrics["train"]), 0)  # doesn't get logged here
         return model

@@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase):
         output_dir = tempfile.mkdtemp(prefix="output_")
         args_d.update(
             data_dir=tmp_dir,
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            tokenizer_name=None,  # "patrickvonplaten/t5-tiny-random",
+            model_name_or_path=T5_TINY,
+            tokenizer_name=None,  # T5_TINY,
             train_batch_size=2,
             eval_batch_size=2,
             gpus=0,
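The fp16/GPU-only tests that were guarded with unittest.skipUnless(torch.cuda.is_available(), ...) are folded into a single test that picks `gpus=1 if torch.cuda.is_available() else 0`, so it degrades to CPU instead of being skipped. A minimal sketch of the two guarding styles (test names are illustrative):

import unittest

import torch


class DeviceAwareTests(unittest.TestCase):
    @unittest.skipUnless(torch.cuda.is_available(), "needs a CUDA device")
    def test_gpu_only_path(self):
        self.assertTrue(torch.cuda.is_available())

    def test_runs_everywhere(self):
        # Mirrors the consolidated test: fall back to CPU instead of skipping.
        gpus = 1 if torch.cuda.is_available() else 0
        self.assertIn(gpus, (0, 1))


if __name__ == "__main__":
    unittest.main()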
examples/summarization/utils.py

@@ -45,8 +45,10 @@ def encode_file(
             max_length=max_length,
             pad_to_max_length=pad_to_max_length,
             add_prefix_space=True,
+            truncation=True,
             return_tensors=return_tensors,
         )
+        assert tokenized.input_ids.shape[1] == max_length
         examples.append(tokenized)
     torch.save(lmap(dict, examples), cache_path.open("wb"))
     return examples
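The added truncation=True plus the new assert pin every cached example to exactly max_length, so the tensors can be stacked later without ragged shapes. A hedged sketch of the same call pattern, assuming the mid-2020 tokenizer API this example directory targets (newer transformers releases prefer padding="max_length"; the checkpoint name is only an example):

from transformers import AutoTokenizer

max_length = 16
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # illustrative checkpoint
tokenized = tokenizer.batch_encode_plus(
    ["a short document", "another slightly longer document"],
    max_length=max_length,
    pad_to_max_length=True,
    truncation=True,
    return_tensors="pt",
)
# Same invariant as the new assert in encode_file(): fixed-width batches.
assert tokenized["input_ids"].shape[1] == max_length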
examples/text-classification/run_pl_glue.py

@@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer):
         return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

-    def _eval_end(self, outputs):
+    def _eval_end(self, outputs) -> tuple:
         val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
         preds = np.concatenate([x["pred"] for x in outputs], axis=0)

@@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer):
         logs = ret["log"]
         return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

-    def test_epoch_end(self, outputs):
-        # updating to test_epoch_end instead of deprecated test_end
+    def test_epoch_end(self, outputs) -> dict:
         ret, predictions, targets = self._eval_end(outputs)
-
-        # Converting to the dic required by pl
-        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
-        # pytorch_lightning/trainer/logging.py#L139
         logs = ret["log"]
         # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
         return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

     @staticmethod
     def add_model_specific_args(parser, root_dir):
-        # Add NER specific options
         BaseTransformer.add_model_specific_args(parser, root_dir)
         parser.add_argument(
             "--max_seq_length",
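The run_pl_glue.py change is mostly annotation and comment cleanup: _eval_end is documented as returning a tuple that test_epoch_end repackages into the dict pytorch-lightning 0.8.x expects. A stripped-down sketch of that contract (standalone functions here, illustrative rather than the repo's class methods):

import numpy as np
import torch


def _eval_end(outputs) -> tuple:
    # Returns (metrics dict, predictions, targets), mirroring GLUETransformer._eval_end.
    val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().item()
    preds = np.concatenate([x["pred"] for x in outputs], axis=0)
    targets = np.concatenate([x["target"] for x in outputs], axis=0)
    logs = {"val_loss": val_loss_mean}
    return {"log": logs}, preds, targets


def test_epoch_end(outputs) -> dict:
    ret, predictions, targets = _eval_end(outputs)
    logs = ret["log"]
    # "val_loss" is the key produced above, but here it refers to the test loss.
    return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}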
src/transformers/tokenization_auto.py

@@ -205,7 +205,7 @@ class AutoTokenizer:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

-        if "bert-base-japanese" in pretrained_model_name_or_path:
+        if "bert-base-japanese" in str(pretrained_model_name_or_path):
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

         use_fast = kwargs.pop("use_fast", False)
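The str(...) cast matters because AutoTokenizer.from_pretrained can be handed a pathlib.Path, and the `in` substring check raises TypeError on a Path object. A small sketch of the failure the cast avoids (the name below is just an example string, not a real local path):

from pathlib import Path

name_as_path = Path("bert-base-japanese-whole-word-masking")

try:
    "bert-base-japanese" in name_as_path  # Path defines neither __contains__ nor __iter__
except TypeError:
    pass  # e.g. "argument of type 'PosixPath' is not iterable" on Linux

assert "bert-base-japanese" in str(name_as_path)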