Unverified Commit be0e189b authored by Zach Mueller, committed by GitHub

Revert frozen training arguments (#25903)

* Revert frozen training arguments

* TODO
parent 69c5b8f1
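
For context, a minimal sketch (not part of the diff) of the behaviour this revert restores: `TrainingArguments` fields are ordinary mutable attributes again, whereas the frozen implementation raised `FrozenInstanceError` on assignment after `__post_init__`. The `output_dir` value below is illustrative.

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_trainer", learning_rate=1e-3)

# Under the frozen implementation this raised dataclasses.FrozenInstanceError;
# after this commit it is a plain attribute assignment again.
args.learning_rate = 5e-4
assert args.learning_rate == 5e-4
```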
@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
         default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
     )
 
-    def __post_init__(self):
-        # Compute absolute learning rate while args are mutable
-        super().__post_init__()
-        if self.base_learning_rate is not None:
-            total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
-            delattr(self, "_frozen")
-            self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
-            setattr(self, "_frozen", True)
-
 
 def collate_fn(examples):
     pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -362,6 +353,13 @@ def main():
     # Set the validation transforms
     ds["validation"].set_transform(preprocess_images)
 
+    # Compute absolute learning rate
+    total_train_batch_size = (
+        training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
+    if training_args.base_learning_rate is not None:
+        training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
+
     # Initialize our trainer
     trainer = Trainer(
         model=model,
...
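
As a worked example of the scaling rule from the hunk above (`absolute_lr = base_lr * total_batch_size / 256`), with made-up values rather than anything taken from this diff:

```python
# Illustrative numbers only; mirrors the computation moved back into main().
base_learning_rate = 1e-3
train_batch_size = 32              # TrainingArguments.train_batch_size (per device * local devices)
gradient_accumulation_steps = 4
world_size = 8

total_train_batch_size = train_batch_size * gradient_accumulation_steps * world_size  # 1024
learning_rate = base_learning_rate * total_train_batch_size / 256
print(learning_rate)  # 0.004
```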
@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import dataclasses
 import logging
 import os
 import sys
@@ -675,10 +674,14 @@ def main():
         return result
 
     # Override the decoding parameters of Seq2SeqTrainer
-    if training_args.generation_max_length is None:
-        training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
-    if training_args.generation_num_beams is None:
-        training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
 
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(
...
@@ -21,7 +21,6 @@ https://huggingface.co/models?filter=fill-mask
 """
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 
-import dataclasses
 import json
 import logging
 import math
@@ -367,7 +366,7 @@ def main():
     # If we have ref files, need to avoid it removed by trainer
     has_ref = data_args.train_ref_file or data_args.validation_ref_file
     if has_ref:
-        training_args = dataclasses.replace(training_args, remove_unused_columns=False)
+        training_args.remove_unused_columns = False
 
     # Data collator
     # This one will take care of randomly masking the tokens.
...
@@ -259,6 +259,7 @@ def main():
             assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
     # endregion
 
@@ -266,8 +267,8 @@
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.info(
...
@@ -265,6 +265,7 @@ def main():
             assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
 
     if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -276,8 +277,8 @@
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.warning(
...
@@ -1172,8 +1172,6 @@ class Trainer:
         elif self.hp_search_backend == HPSearchBackend.WANDB:
             params = trial
 
-        # Unfreeze args for hyperparameter search
-        delattr(self.args, "_frozen")
         for key, value in params.items():
             if not hasattr(self.args, key):
                 logger.warning(
@@ -1205,8 +1203,6 @@ class Trainer:
                 self.args.hf_deepspeed_config.trainer_config_process(self.args)
                 self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
 
-        # Re-freeze them
-        setattr(self.args, "_frozen", True)
         self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
...
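
With the `_frozen` guard gone, applying a hyperparameter-search trial reduces to a plain `setattr` loop; a hedged sketch of that pattern (the trial values below are made up, not taken from this diff):

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_trainer")
trial = {"learning_rate": 3e-5, "per_device_train_batch_size": 16}

for key, value in trial.items():
    if not hasattr(args, key):
        print(f"Trying to set {key} but there is no corresponding field in TrainingArguments.")
        continue
    # No delattr(args, "_frozen") / setattr(args, "_frozen", True) dance needed anymore.
    setattr(args, key, value)
```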
@@ -18,7 +18,7 @@ import json
 import math
 import os
 import warnings
-from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     PAGED_LION_8BIT = "paged_lion_8bit"
 
 
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
 @dataclass
 class TrainingArguments:
     """
@@ -1707,16 +1708,6 @@ class TrainingArguments:
                 FutureWarning,
             )
 
-        # Finally set the `TrainingArguments` to be immutable
-        self._frozen = True
-
-    def __setattr__(self, name, value):
-        # Once fully through the `__post_init__`, `TrainingArguments` are immutable
-        if not name.startswith("_") and getattr(self, "_frozen", False):
-            raise FrozenInstanceError(f"cannot assign to field {name}")
-        else:
-            super().__setattr__(name, value)
-
     def __str__(self):
         self_as_dict = asdict(self)
...
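
Subclasses can likewise derive fields in `__post_init__` with plain assignments again; a hedged sketch modeled on the `CustomTrainingArguments` removed earlier in this diff (the class name and default value below are illustrative):

```python
from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class ScaledLRTrainingArguments(TrainingArguments):
    base_learning_rate: float = field(default=1e-3)

    def __post_init__(self):
        super().__post_init__()
        total = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
        # Plain assignment: no unfreezing workaround required after this revert.
        self.learning_rate = self.base_learning_rate * total / 256
```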
@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
     b: float = 0.0
 
     def __post_init__(self):
+        super().__post_init__()
         # save resources not dealing with reporting (also avoids the warning when it's not set)
         self.report_to = []
-        super().__post_init__()
 
 
 class RepeatDataset:
@@ -529,8 +529,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.check_trained_model(trainer.model)
 
         # Re-training should restart from scratch, thus lead the same results and new seed should be used.
-        args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
-        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
+        trainer.args.seed = 314
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import dataclasses
 from typing import Dict
 
 import numpy as np
@@ -206,14 +205,7 @@ if __name__ == "__main__":
             logger.error(p.metrics)
             exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = 2
 
     metrics = trainer.evaluate()
     logger.info(metrics)
@@ -227,22 +219,15 @@ if __name__ == "__main__":
             logger.error(p.metrics)
             exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = None
 
     # Check that `dispatch_batches=False` will work on a finite iterable dataset
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
 
     model = RegressionModel()
-    training_args = dataclasses.replace(
-        training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
-    )
+    training_args.per_device_train_batch_size = 1
+    training_args.max_steps = 1
+    training_args.dispatch_batches = False
     trainer = Trainer(model, training_args, train_dataset=train_dataset)
     trainer.train()