Unverified Commit fb78a90d authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

PL: --adafactor option (#6776)

parent 92ac2fa7
......@@ -22,6 +22,7 @@ from transformers import (
PreTrainedTokenizer,
)
from transformers.optimization import (
Adafactor,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
......@@ -137,7 +138,15 @@ class BaseTransformer(pl.LightningModule):
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
if self.hparams.adafactor:
optimizer = Adafactor(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
)
else:
optimizer = AdamW(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
)
self.opt = optimizer
scheduler = self.get_lr_scheduler()
......@@ -251,6 +260,7 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
parser.add_argument("--adafactor", action="store_true")
class LoggingCallback(pl.Callback):
......
......@@ -30,6 +30,7 @@ logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing": 0.2,
"adafactor": True,
"early_stopping_patience": 2,
"logger_name": "default",
"length_penalty": 0.5,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment