Unverified Commit fb78a90d authored by Sam Shleifer, committed by GitHub

PL: --adafactor option (#6776)

parent 92ac2fa7
@@ -22,6 +22,7 @@ from transformers import (
     PreTrainedTokenizer,
 )
 from transformers.optimization import (
+    Adafactor,
     get_cosine_schedule_with_warmup,
     get_cosine_with_hard_restarts_schedule_with_warmup,
     get_linear_schedule_with_warmup,
@@ -137,7 +138,15 @@ class BaseTransformer(pl.LightningModule):
                 "weight_decay": 0.0,
             },
         ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
+        if self.hparams.adafactor:
+            optimizer = Adafactor(
+                optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
+            )
+        else:
+            optimizer = AdamW(
+                optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
+            )
         self.opt = optimizer
         scheduler = self.get_lr_scheduler()
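
Note (not part of the commit): a minimal, self-contained sketch of how the Adafactor branch above constructs the optimizer. `model` and the learning-rate value are placeholder assumptions; the point is that passing an explicit lr to transformers' Adafactor requires relative_step=False, and scale_parameter=False additionally turns off its parameter-scale heuristic so the external LR scheduler keeps control of the step size.

# Illustrative sketch only -- `model` and the lr value are stand-ins, not from the diff.
import torch
from transformers.optimization import Adafactor

model = torch.nn.Linear(8, 8)  # placeholder for the wrapped transformer
optimizer = Adafactor(
    model.parameters(),
    lr=3e-5,                # explicit learning rate ...
    relative_step=False,    # ... so Adafactor's time-dependent step size must be off
    scale_parameter=False,  # also disable the parameter-scale heuristic
)
loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()  # behaves like any torch.optim optimizer
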
@@ -251,6 +260,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
         parser.add_argument("--train_batch_size", default=32, type=int)
         parser.add_argument("--eval_batch_size", default=32, type=int)
+        parser.add_argument("--adafactor", action="store_true")

 class LoggingCallback(pl.Callback):
......
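
Note (illustrative, not from the diff): --adafactor is a plain store_true switch, so it defaults to False and AdamW stays the default optimizer; behavior only changes when the flag is passed explicitly.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adafactor", action="store_true")

print(parser.parse_args([]).adafactor)               # False -> AdamW branch
print(parser.parse_args(["--adafactor"]).adafactor)  # True  -> Adafactor branch
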
@@ -30,6 +30,7 @@ logger = logging.getLogger()
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
+    "adafactor": True,
     "early_stopping_patience": 2,
     "logger_name": "default",
     "length_penalty": 0.5,
......
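
Note (an assumption about the surrounding test file, which the diff does not show): CHEAP_ARGS appears to be the shared argument dict from which the cheap end-to-end tests build their Namespace, so setting "adafactor": True routes those runs through the new Adafactor branch. A hypothetical sketch:

import argparse

CHEAP_ARGS = {"label_smoothing": 0.2, "adafactor": True}  # abridged
args = argparse.Namespace(**CHEAP_ARGS)
assert args.adafactor  # the tests now exercise the Adafactor code path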