transformers · commit ffceef20 (unverified)
Authored Aug 06, 2020 by Bhashithe Abeysinghe; committed by GitHub, Aug 06, 2020
Parent: eb2bd8d6

[Fix] text-classification PL example (#6027)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Changes: 3 changed files with 14 additions and 6 deletions

  examples/lightning_base.py                    +7 -3
  examples/text-classification/run_pl.sh        +1 -1
  examples/text-classification/run_pl_glue.py   +6 -2
examples/lightning_base.py
@@ -73,7 +73,7 @@ class BaseTransformer(pl.LightningModule):
         # self.save_hyperparameters()
         # can also expand arguments into trainer signature for easier reading
-        self.hparams = hparams
+        self.save_hyperparameters(hparams)
         self.step_count = 0
         self.output_dir = Path(self.hparams.output_dir)
         cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
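Note on this hunk: save_hyperparameters(hparams) keeps attribute access working (self.hparams.output_dir below is unchanged) and additionally records the values in every checkpoint, so load_from_checkpoint can rebuild the module without the original argparse Namespace. A minimal sketch of that behaviour, using a toy module rather than the example's code:

from argparse import Namespace

import pytorch_lightning as pl
import torch

class Toy(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # Records every field of `hparams`: readable as self.hparams.*
        # and embedded in checkpoints saved from this module.
        self.save_hyperparameters(hparams)
        self.layer = torch.nn.Linear(4, 2)

print(Toy(Namespace(lr=3e-5)).hparams.lr)  # 3e-05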
@@ -245,7 +245,7 @@ class BaseTransformer(pl.LightningModule):
 class LoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
-        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
+        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
         pl_module.logger.log_metrics(lrs)

     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
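Note on this hunk: the old line read self.lr_scheduler, but LoggingCallback has no such attribute (the scheduler lives on the module, not the callback), so on_batch_end raised AttributeError. Pulling the current rates from the optimizer's param groups needs no scheduler at all. The same pattern in plain PyTorch, as a sketch:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One entry per parameter group, keyed exactly like the callback's lrs dict.
lrs = {f"lr_group_{i}": g["lr"] for i, g in enumerate(optimizer.param_groups)}
print(lrs)  # {'lr_group_0': 0.1}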
@@ -278,6 +278,10 @@ def add_generic_args(parser, root_dir) -> None:
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--gpus", default=0, type=int,
+        help="The number of GPUs allocated for this, it is by default 0 meaning none",
+    )
     parser.add_argument(
         "--fp16",
         action="store_true",
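The new --gpus flag is what run_pl.sh (below) now passes; presumably lightning_base's generic_train forwards it to the pl.Trainer, though that wiring is not part of this diff. A quick sketch of the parsing side only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--gpus", default=0, type=int,
    help="The number of GPUs allocated for this, it is by default 0 meaning none",
)
args = parser.parse_args(["--gpus", "1"])
# Presumably consumed as pl.Trainer(gpus=args.gpus, ...) downstream.
print(args.gpus)  # 1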
@@ -291,7 +295,7 @@ def add_generic_args(parser, root_dir) -> None:
         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
         "See details at https://nvidia.github.io/apex/amp.html",
     )
-    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
+    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
     parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
     parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
     parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
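Dropping default=0 from --n_tpu_cores means args.tpu_cores is None when the flag is omitted, which is, to my knowledge, what PyTorch Lightning expects for "no TPUs" (0 is not an accepted tpu_cores value). A sketch of the difference:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)

# Without default=0 the attribute is None unless the flag is given,
# so a downstream pl.Trainer(tpu_cores=...) sees None rather than 0.
print(parser.parse_args([]).tpu_cores)  # None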
examples/text-classification/run_pl.sh
@@ -23,7 +23,7 @@ mkdir -p $OUTPUT_DIR
 # Add parent directory to python path to access lightning_base.py
 export PYTHONPATH="../":"${PYTHONPATH}"

-python3 run_pl_glue.py --data_dir $DATA_DIR \
+python3 run_pl_glue.py --gpus 1 --data_dir $DATA_DIR \
 --task $TASK \
 --model_name_or_path $BERT_MODEL \
 --output_dir $OUTPUT_DIR \
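Since the script now requests one GPU, it assumes a visible CUDA device. A pre-flight check in plain PyTorch (my addition, not part of the commit):

import torch

# If this prints 0, running run_pl.sh with --gpus 1 will fail at Trainer startup.
print(torch.cuda.device_count())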
examples/text-classification/run_pl_glue.py
@@ -3,6 +3,7 @@ import glob
 import logging
 import os
 import time
+from argparse import Namespace

 import numpy as np
 import torch
@@ -24,6 +25,8 @@ class GLUETransformer(BaseTransformer):
     mode = "sequence-classification"

     def __init__(self, hparams):
+        if type(hparams) == dict:
+            hparams = Namespace(**hparams)
         hparams.glue_output_mode = glue_output_modes[hparams.task]
         num_labels = glue_tasks_num_labels[hparams.task]
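The dict-to-Namespace guard exists because the hyperparameters can arrive as a plain dict (for example when Lightning restores them from a checkpoint), while the lines below rely on attribute access. A standalone sketch of the conversion, with hypothetical values:

from argparse import Namespace

hparams = {"task": "mrpc", "learning_rate": 3e-5}  # hypothetical values
if type(hparams) == dict:
    hparams = Namespace(**hparams)
print(hparams.task)  # attribute access now works: "mrpc"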
@@ -41,7 +44,8 @@ class GLUETransformer(BaseTransformer):
         outputs = self(**inputs)
         loss = outputs[0]

-        tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+        # tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+        tensorboard_logs = {"loss": loss}
         return {"loss": loss, "log": tensorboard_logs}

     def prepare_data(self):
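Logging only the loss sidesteps the self.lr_scheduler.get_last_lr() lookup that the commit disables. If the rate is wanted again, reading it from the optimizer avoids the scheduler entirely; this is a variant I am suggesting, not the commit's code:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

loss = torch.tensor(0.7)  # stand-in for this step's loss
tensorboard_logs = {"loss": loss, "rate": optimizer.param_groups[-1]["lr"]}
print(tensorboard_logs["rate"])  # 3e-05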
@@ -71,7 +75,7 @@ class GLUETransformer(BaseTransformer):
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)

-    def load_dataset(self, mode, batch_size):
+    def get_dataloader(self, mode: int, batch_size: int, shuffle: bool) -> DataLoader:
         "Load datasets. Called after prepare data."
         # We test on dev set to compare to benchmarks without having to submit to GLUE server
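The rename matters because BaseTransformer in lightning_base.py apparently dispatches to a get_dataloader hook from its train/val/test dataloader methods, so a subclass method named load_dataset was never called. (Aside: the new annotation says mode: int, though the values passed around are strings like "dev".) A minimal sketch of the implied contract, with the base-class internals assumed rather than taken from this diff:

import torch
from torch.utils.data import DataLoader, TensorDataset

class Base:
    # Stand-in for BaseTransformer's dispatch (assumed shape).
    def val_dataloader(self):
        return self.get_dataloader("dev", batch_size=32, shuffle=False)

class Task(Base):
    def get_dataloader(self, mode, batch_size, shuffle):
        dataset = TensorDataset(torch.zeros(8, 4))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

print(len(Task().val_dataloader()))  # 1 batch of 8 rows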