chenpangpang / transformers · Commits

Unverified commit 84c265ff, authored Aug 16, 2020 by Sam Shleifer, committed by GitHub on Aug 16, 2020
[lightning_base] fix s2s logging, only make train_loader once (#6404)
parent 72add6c9

Showing 6 changed files with 47 additions and 72 deletions (+47 -72)

examples/lightning_base.py                           +21  -9
examples/seq2seq/distillation.py                      +1  -26
examples/seq2seq/finetune.py                          +1  -19
examples/text-classification/run_pl_glue.py           +1  -8
examples/token-classification/run_pl_ner.py           +1  -10
examples/token-classification/test_ner_examples.py   +22  -0
examples/lightning_base.py

@@ -150,15 +150,20 @@ class BaseTransformer(pl.LightningModule):
     def test_epoch_end(self, outputs):
         return self.validation_end(outputs)
 
-    def setup(self, step):
-        train_batch_size = self.hparams.train_batch_size
-        dataloader = self.get_dataloader("train", train_batch_size)
-        self.train_loader = dataloader
-        self.total_steps = (
-            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
+    @property
+    def total_steps(self) -> int:
+        """The number of total training steps that will be run. Used for lr scheduler purposes."""
+        num_devices = max(1, self.hparams.gpus)  # TODO: consider num_tpu_cores
+        effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
+        dataset_size = len(self.train_loader.dataset)
+        return (dataset_size / effective_batch_size) * self.hparams.max_epochs
+
+    def setup(self, mode):
+        if mode == "fit":
+            self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
 
     def get_dataloader(self, type_path, batch_size, shuffle=False):
         raise NotImplementedError("You must implement this for your task")
 
     def train_dataloader(self):
         return self.train_loader

@@ -304,6 +309,13 @@ def add_generic_args(parser, root_dir) -> None:
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+    parser.add_argument(
+        "--data_dir",
+        default=None,
+        type=str,
+        required=True,
+        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
+    )
 
 
 def generic_train(
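Since setup("fit") now builds self.train_loader exactly once, train_dataloader() can simply return it, which is the "only make train_loader once" part of the commit message. Below is a minimal standalone sketch of the arithmetic behind the new total_steps property; the hyperparameter values are made up for illustration and are not part of the commit (the real module reads them from self.hparams and self.train_loader).

from argparse import Namespace

# Hypothetical hyperparameters, chosen only to illustrate the arithmetic
# performed by the new `total_steps` property.
hparams = Namespace(train_batch_size=8, accumulate_grad_batches=4, gpus=2, max_epochs=3)
dataset_size = 10_000  # stands in for len(self.train_loader.dataset)

num_devices = max(1, hparams.gpus)
effective_batch_size = hparams.train_batch_size * hparams.accumulate_grad_batches * num_devices
total_steps = (dataset_size / effective_batch_size) * hparams.max_epochs
print(effective_batch_size, total_steps)  # 64 468.75 -> roughly 469 optimizer updates over 3 epochs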
examples/seq2seq/distillation.py

@@ -10,14 +10,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from lightning_base import generic_train
-from transformers import (
-    AdamW,
-    BartConfig,
-    BartForConditionalGeneration,
-    MBartTokenizer,
-    T5Config,
-    T5ForConditionalGeneration,
-)
+from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
 
 
 try:

@@ -158,24 +151,6 @@ class BartSummarizationDistiller(SummarizationModule):
         )
         return loss_ce, s_logits_slct, t_logits_slct
 
-    def configure_optimizers(self):
-        "Prepare optimizer and schedule (linear warmup and decay)"
-        model = self.model
-        no_decay = ["bias", "LayerNorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-                "weight_decay": self.hparams.weight_decay,
-            },
-            {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-                "weight_decay": 0.0,
-            },
-        ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
-        self.opt = optimizer
-        return [optimizer]
-
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         SummarizationModule.add_model_specific_args(parser, root_dir)
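The deleted configure_optimizers override duplicated the optimizer setup in the shared BaseTransformer, so the distiller can simply inherit it. For reference, a self-contained sketch of the weight-decay grouping pattern the removed code used; the toy model and values are illustrative only, and torch.optim.AdamW stands in for the transformers.AdamW used by the example scripts.

import torch
from torch.optim import AdamW  # stand-in for transformers.AdamW in this sketch


def grouped_parameters(model: torch.nn.Module, weight_decay: float = 0.01):
    """Split parameters into decayed / non-decayed groups, as the removed override did."""
    # In real transformer models, parameter names include "bias" and "LayerNorm.weight".
    no_decay = ["bias", "LayerNorm.weight"]
    return [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
optimizer = AdamW(grouped_parameters(model), lr=3e-5, eps=1e-8)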
examples/seq2seq/finetune.py

@@ -3,7 +3,6 @@ import glob
 import logging
 import os
-import time
 import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple

@@ -14,7 +13,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
+from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
 
 
 try:

@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):
     def train_dataloader(self) -> DataLoader:
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
-        t_total = (
-            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        if max(scheduler.get_last_lr()) > 0:
-            warnings.warn("All learning rates are 0")
-        self.lr_scheduler = scheduler
         return dataloader
 
     def val_dataloader(self) -> DataLoader:

@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
             help="The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded.",
         )
-        parser.add_argument(
-            "--data_dir",
-            type=str,
-            required=True,
-            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
-        )
         parser.add_argument("--freeze_encoder", action="store_true")
         parser.add_argument("--freeze_embeds", action="store_true")
         parser.add_argument("--sortish_sampler", action="store_true", default=False)
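With get_linear_schedule_with_warmup no longer imported here, the warmup scheduler is presumably created once by the shared lightning_base module (using the new total_steps property) instead of on every train_dataloader() call. A minimal sketch of how such a scheduler is typically wired up, with a toy model and made-up step counts; these values are not taken from the repo.

import torch
from transformers import get_linear_schedule_with_warmup

# Illustrative values only; in the example scripts the optimizer, warmup steps
# and total step count come from hparams and BaseTransformer.total_steps.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

for step in range(3):
    optimizer.step()      # normally preceded by a forward/backward pass
    scheduler.step()
    print(scheduler.get_last_lr())  # learning rate ramps up linearly during warmup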
examples/text-classification/run_pl_glue.py

@@ -75,7 +75,7 @@ class GLUETransformer(BaseTransformer):
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
 
-    def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
+    def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         "Load datasets. Called after prepare data."
 
         # We test on dev set to compare to benchmarks without having to submit to GLUE server

@@ -161,13 +161,6 @@ class GLUETransformer(BaseTransformer):
             type=int,
             help="The number of GPUs allocated for this, it is by default 0 meaning none",
         )
-        parser.add_argument(
-            "--data_dir",
-            default=None,
-            type=str,
-            required=True,
-            help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-        )
         parser.add_argument(
             "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
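The annotation fix reflects that get_dataloader receives a split name such as "train" or "dev", not an integer. A toy sketch of that contract; this is not the repo's class, and the dataset contents are placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyGlueModule:
    """Minimal stand-in showing the string-typed `mode` argument."""

    def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        n = 32 if mode == "train" else 8
        dataset = TensorDataset(torch.zeros(n, 4))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


module = ToyGlueModule()
train_loader = module.get_dataloader("train", batch_size=8, shuffle=True)  # mirrors setup("fit")
dev_loader = module.get_dataloader("dev", batch_size=8)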
examples/token-classification/run_pl_ner.py

@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
         )
 
     def validation_step(self, batch, batch_nb):
-        "Compute validation"
-        ""
+        """Compute validation"""
         inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
         if self.config.model_type != "distilbert":
             inputs["token_type_ids"] = (

@@ -191,14 +190,6 @@ class NERTransformer(BaseTransformer):
             help="The number of GPUs allocated for this, it is by default 0 meaning none",
         )
-        parser.add_argument(
-            "--data_dir",
-            default=None,
-            type=str,
-            required=True,
-            help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-        )
         parser.add_argument(
             "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
         )
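Dropping the per-script --data_dir flag avoids defining the same option twice now that add_generic_args owns it; argparse raises an error when an option string is registered twice on one parser. A minimal sketch of the composition follows; the helper bodies below are simplified stand-ins, not the repo's code.

import argparse


def add_generic_args(parser):
    # --data_dir is now defined once here, for every example script.
    parser.add_argument("--data_dir", default=None, type=str, required=True)
    parser.add_argument("--seed", type=int, default=42)


def add_model_specific_args(parser):
    # Task-specific flags only; no duplicate --data_dir.
    parser.add_argument("--overwrite_cache", action="store_true")


parser = argparse.ArgumentParser()
add_generic_args(parser)
add_model_specific_args(parser)
args = parser.parse_args(["--data_dir", "./data", "--overwrite_cache"])
print(args.data_dir, args.overwrite_cache)  # ./data True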
examples/token-classification/test_ner_examples.py

@@ -4,6 +4,7 @@ import unittest
 from unittest.mock import patch
 
 import run_ner
+from transformers.testing_utils import slow
 
 
 logging.basicConfig(level=logging.INFO)

@@ -12,6 +13,7 @@ logger = logging.getLogger()
 class ExamplesTests(unittest.TestCase):
+    @slow
     def test_run_ner(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)

@@ -31,3 +33,23 @@ class ExamplesTests(unittest.TestCase):
         with patch.object(sys, "argv", ["run.py"] + testargs):
             result = run_ner.main()
             self.assertLess(result["eval_loss"], 1.5)
+
+    def test_run_ner_pl(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        testargs = """
+            --model_name distilbert-base-german-cased
+            --output_dir ./tests/fixtures/tests_samples/temp_dir
+            --overwrite_output_dir
+            --data_dir ./tests/fixtures/tests_samples/GermEval
+            --labels ./tests/fixtures/tests_samples/GermEval/labels.txt
+            --max_seq_length 128
+            --num_train_epochs 6
+            --logging_steps 1
+            --do_train
+            --do_eval
+            """.split()
+        with patch.object(sys, "argv", ["run.py"] + testargs):
+            result = run_ner.main()
+            self.assertLess(result["eval_loss"], 1.5)
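The @slow marker comes from transformers.testing_utils and, as I understand it, skips the decorated test unless slow tests are enabled via the RUN_SLOW environment variable. A rough, simplified sketch of how such a gate can be built; this is not the library's actual implementation.

import os
import unittest


def slow(test_case):
    """Skip the test unless RUN_SLOW is set to a truthy value (simplified sketch)."""
    if os.getenv("RUN_SLOW", "0").lower() not in ("1", "true", "yes"):
        return unittest.skip("test is slow; set RUN_SLOW=1 to run it")(test_case)
    return test_case


class ExampleTests(unittest.TestCase):
    @slow
    def test_heavy_training_run(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()

On a real checkout, the tests in this file would then be run with something like RUN_SLOW=1 pytest examples/token-classification/test_ner_examples.py.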