"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "0682dae0275af6ab376cfa346ef27562b574684d"
Unverified Commit ffceef20 authored by Bhashithe Abeysinghe, committed by GitHub

[Fix] text-classification PL example (#6027)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent eb2bd8d6
```diff
@@ -73,7 +73,7 @@ class BaseTransformer(pl.LightningModule):
         # self.save_hyperparameters()
         # can also expand arguments into trainer signature for easier reading
-        self.hparams = hparams
+        self.save_hyperparameters(hparams)
         self.step_count = 0
         self.output_dir = Path(self.hparams.output_dir)
         cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
```
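For context, a minimal sketch of what this change relies on (toy module, not the repo's `BaseTransformer`): in recent pytorch_lightning releases, `save_hyperparameters()` is the supported way to populate `self.hparams` and have the values persisted in checkpoints, while assigning `self.hparams` directly is deprecated or disallowed.

```python
# Hedged sketch, assuming a recent pytorch_lightning where hparams is a
# managed property: save_hyperparameters() stores the values under
# self.hparams and records them in saved checkpoints.
from argparse import Namespace

import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)  # replaces `self.hparams = hparams`


model = ToyModule(Namespace(output_dir="out", learning_rate=3e-5))
print(model.hparams.learning_rate)  # 3e-05
```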
```diff
@@ -245,7 +245,7 @@ class BaseTransformer(pl.LightningModule):
 class LoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
-        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
+        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
         pl_module.logger.log_metrics(lrs)

     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
```
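Why the replacement works, as a standalone sketch in plain torch (no Lightning needed): each optimizer param group exposes its current learning rate under the `"lr"` key, so the callback can read rates without holding a scheduler reference. The old line looked up `self.lr_scheduler` on the callback itself, which has no such attribute.

```python
# Standalone sketch of the new lookup: optimizer.param_groups is a list of
# dicts, and each dict carries the group's current "lr" value.
import torch

weight = torch.nn.Parameter(torch.zeros(2, 2))
optimizer = torch.optim.AdamW([weight], lr=3e-5)

lrs = {f"lr_group_{i}": group["lr"] for i, group in enumerate(optimizer.param_groups)}
print(lrs)  # {'lr_group_0': 3e-05}
```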
```diff
@@ -278,6 +278,10 @@ def add_generic_args(parser, root_dir) -> None:
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--gpus", default=0, type=int, help="The number of GPUs allocated for this, it is by default 0 meaning none",
+    )
     parser.add_argument(
         "--fp16",
         action="store_true",
```
```diff
@@ -291,7 +295,7 @@ def add_generic_args(parser, root_dir) -> None:
         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
         "See details at https://nvidia.github.io/apex/amp.html",
     )
-    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
+    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
     parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
     parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
     parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
```
```diff
@@ -23,7 +23,7 @@ mkdir -p $OUTPUT_DIR
 # Add parent directory to python path to access lightning_base.py
 export PYTHONPATH="../":"${PYTHONPATH}"
-python3 run_pl_glue.py --data_dir $DATA_DIR \
+python3 run_pl_glue.py --gpus 1 --data_dir $DATA_DIR \
 --task $TASK \
 --model_name_or_path $BERT_MODEL \
 --output_dir $OUTPUT_DIR \
```
```diff
@@ -3,6 +3,7 @@ import glob
 import logging
 import os
 import time
+from argparse import Namespace

 import numpy as np
 import torch
```
```diff
@@ -24,6 +25,8 @@ class GLUETransformer(BaseTransformer):
     mode = "sequence-classification"

     def __init__(self, hparams):
+        if type(hparams) == dict:
+            hparams = Namespace(**hparams)
         hparams.glue_output_mode = glue_output_modes[hparams.task]
         num_labels = glue_tasks_num_labels[hparams.task]
```
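A hedged sketch of the guard added above, outside the class: in some Lightning versions, restoring a module from a checkpoint can hand `hparams` back as a plain dict, which has no attribute access, so the shim rebuilds a Namespace before `hparams.task` is read.

```python
# Sketch of the dict-to-Namespace shim: a plain dict lacks attribute access,
# so hparams.task would raise AttributeError without the conversion.
from argparse import Namespace

hparams = {"task": "mrpc", "output_dir": "out"}  # hypothetical restored dict
if type(hparams) == dict:
    hparams = Namespace(**hparams)

print(hparams.task)  # "mrpc"
```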
```diff
@@ -41,7 +44,8 @@ class GLUETransformer(BaseTransformer):
         outputs = self(**inputs)
         loss = outputs[0]

-        tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+        # tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+        tensorboard_logs = {"loss": loss}
         return {"loss": loss, "log": tensorboard_logs}

     def prepare_data(self):
```
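For reference, what the commented-out line computed, as a self-contained sketch with throwaway torch objects standing in for the module's: `get_last_lr()` needs a scheduler handle, and the rate entry is dropped here presumably because the module no longer reliably holds one at this point in `training_step`, so only the loss is logged.

```python
# Sketch of the removed "rate" computation with a throwaway scheduler.
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

rate = scheduler.get_last_lr()[-1]  # what the old log line reported
loss = torch.tensor(0.0)            # stand-in for the model loss
tensorboard_logs = {"loss": loss}   # the surviving payload after this commit
```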
```diff
@@ -71,7 +75,7 @@ class GLUETransformer(BaseTransformer):
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)

-    def load_dataset(self, mode, batch_size):
+    def get_dataloader(self, mode: int, batch_size: int, shuffle: bool) -> DataLoader:
         "Load datasets. Called after prepare data."

         # We test on dev set to compare to benchmarks without having to submit to GLUE server
```
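A hedged sketch of the renamed hook's contract, with a toy dataset in place of the cached GLUE features the real method loads: the base class is assumed to call `get_dataloader(mode, batch_size, shuffle)` from its train/val/test dataloader hooks, so the override must match this exact signature.

```python
# Toy sketch of the new hook shape; TensorDataset stands in for the cached
# GLUE features the real method builds a DataLoader over.
import torch
from torch.utils.data import DataLoader, TensorDataset


def get_dataloader(mode, batch_size: int, shuffle: bool) -> DataLoader:
    features = TensorDataset(torch.zeros(8, 4))  # hypothetical cached features
    return DataLoader(features, batch_size=batch_size, shuffle=shuffle)


loader = get_dataloader("train", batch_size=4, shuffle=True)
print(len(loader))  # 2 batches of 4
```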