Unverified Commit a01b033c authored by Anton Vlasjuk, committed by GitHub

Fix galore lr display with schedulers (#31710)

* fix galore lr display with lr schedulers

* style

* add some tests to check for displayed lrs

* fix copy-paste error for warmup steps

* standardize the default lr to be only in the optimizer

* trying out my luck with the reads
parent ac262604
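
For context on what this fixes: in GaLore's layer-wise mode each parameter owns its own optimizer (collected in an optimizer_dict behind a LayerWiseDummyOptimizer) and its own scheduler, both driven by gradient hooks, so the scheduler the Trainer sees is only a stand-in. The sketch below is plain PyTorch with hypothetical names, not the transformers code; it shows that setup and why the stand-in needs access to optimizer_dict before it can display real learning rates.

# Minimal sketch (plain PyTorch, hypothetical names) of the layer-wise setup:
# each parameter gets its own optimizer and scheduler, stepped from a
# post-accumulate-grad hook, so the Trainer-level scheduler is only a stand-in.
import torch

model = torch.nn.Linear(4, 4)

optimizer_dict = {p: torch.optim.AdamW([p], lr=2e-4) for p in model.parameters()}
scheduler_dict = {
    p: torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 0.5**step)
    for p, opt in optimizer_dict.items()
}

def optimizer_hook(param):
    # fires once this parameter's gradient is accumulated; steps only its own state
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()
    scheduler_dict[param].step()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

model(torch.randn(2, 4)).sum().backward()  # hooks run here, once per parameter

# the top-level "dummy" scheduler can only display real lrs if it can see optimizer_dict:
print([g["lr"] for opt in optimizer_dict.values() for g in opt.param_groups])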
@@ -519,7 +519,7 @@ def get_scheduler(
             if param.requires_grad:
                 param.register_post_accumulate_grad_hook(scheduler_hook)

-        return LayerWiseDummyScheduler()
+        return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])

     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
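
A quick note on the new arguments (a sketch, not the library code): optimizer.defaults["lr"] is the standard PyTorch attribute holding the lr an optimizer was constructed with, and it is what the dummy scheduler now uses as a display fallback when no per-parameter groups are available.

# Sketch: .defaults["lr"] on any torch optimizer returns the constructor-time lr,
# which is what gets forwarded to the dummy scheduler as the display fallback.
import torch

param = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.AdamW([param], lr=2e-4)
print(opt.defaults["lr"])  # 0.0002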
@@ -27,6 +27,7 @@ import warnings
 from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
+from itertools import chain
 from logging import StreamHandler
 from typing import Any, Dict, Iterator, List, Optional, Union
@@ -1379,13 +1380,24 @@ class LayerWiseDummyScheduler(LRScheduler):
     """

     def __init__(self, *args, **kwargs):
-        optimizer = LayerWiseDummyOptimizer()
+        self.default_lr = kwargs["lr"]
+        optimizer = LayerWiseDummyOptimizer(**kwargs)
         last_epoch = -1
         verbose = False
         super().__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
-        return [group["lr"] for group in self.optimizer.param_groups]
+        # default value
+        lrs = [self.default_lr]
+
+        # we take each lr in the parameters if they exist, assumes the optimizer to be the `LayerWiseDummyOptimizer`
+        if self.optimizer is not None:
+            param_wise_lrs = [
+                [group["lr"] for group in optim.param_groups] for optim in self.optimizer.optimizer_dict.values()
+            ]
+            lrs = list(chain(*param_wise_lrs))
+
+        return lrs

     def _get_closed_form_lr(self):
         return self.base_lrs
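
The new get_lr flattens one lr per param group across every per-parameter optimizer in optimizer_dict. Here is the same flattening pattern in isolation, with plain torch optimizers standing in for the GaLore ones (an assumed setup, not the transformers classes):

# Sketch: flattening learning rates across several per-parameter optimizers,
# mirroring what the updated get_lr does with optimizer_dict.
from itertools import chain

import torch

params = [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)]
optimizer_dict = {p: torch.optim.AdamW([p], lr=2e-4) for p in params}

param_wise_lrs = [
    [group["lr"] for group in optim.param_groups] for optim in optimizer_dict.values()
]
lrs = list(chain(*param_wise_lrs))
print(lrs)  # [0.0002, 0.0002, 0.0002] -- one entry per param group, across all optimizers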
@@ -1653,6 +1653,84 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(galore_peak_memory < upper_bound_pm)
         self.assertTrue(lower_bound_pm < galore_peak_memory)

+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_lr_display_without_scheduler(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            learning_rate = 1e-9
+            num_steps = 10
+
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=learning_rate,
+                logging_steps=5,
+                optim="galore_adamw",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+            trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+            # reflects displayed lr in trainer
+            self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_lr_display_with_scheduler(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            learning_rate = 2e-4
+            num_train_epochs = 2
+            num_warmup_steps = 5
+
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                num_train_epochs=num_train_epochs,
+                learning_rate=learning_rate,
+                warmup_steps=num_warmup_steps,
+                lr_scheduler_type="cosine",
+                logging_steps=1,
+                optim="galore_adamw",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # creating log history of trainer, results don't matter
+            trainer.train()
+            logs = trainer.state.log_history[1:][:-1]
+
+            # reach given learning rate peak and end with 0 lr
+            self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
+            self.assertTrue(logs[-1]["learning_rate"] == 0)
+
+            # increasing and decreasing pattern of lrs
+            increasing_lrs = [
+                logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
+                for i in range(len(logs))
+                if i < num_warmup_steps - 2
+            ]
+            decreasing_lrs = [
+                logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
+                for i in range(len(logs) - 1)
+                if i >= num_warmup_steps - 2
+            ]
+
+            self.assertTrue(all(increasing_lrs))
+            self.assertTrue(all(decreasing_lrs))
+
+            # warm up steps << total steps
+            self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
+
     @require_torch_multi_accelerator
     def test_data_is_not_parallelized_when_model_is_parallel(self):
         model = RegressionModel()
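
The second test asserts that the displayed lr climbs to the configured peak during warmup and then decays toward zero under the cosine schedule. A minimal standalone check of that shape, using get_cosine_schedule_with_warmup directly (a simplification for illustration; the test above exercises the full Trainer + GaLore path):

# Sketch: a cosine schedule with warmup rises to the peak lr, then decays towards 0.
import torch
from transformers import get_cosine_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(2))
optimizer = torch.optim.AdamW([param], lr=2e-4)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=5, num_training_steps=20)

lrs = []
for _ in range(20):
    lrs.append(scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()

assert max(lrs) == 2e-4                          # peak reached at the end of warmup
assert lrs[:5] == sorted(lrs[:5])                # increasing during warmup
assert lrs[5:] == sorted(lrs[5:], reverse=True)  # decreasing afterwards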