"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5603fad2479ad22ca4689f6a4dbf56ef2f1f0973"
Unverified Commit ca514992 authored by Zach Mueller, committed by GitHub

Make training args fully immutable (#25435)

* Make training args fully immutable

* Working tests, PyTorch

* In test_trainer

* during testing

* Use proper dataclass way

* Fix test

* Another one

* Fix tf

* Lingering slow

* Exception

* Clean
parent f11518a5
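The diff below makes `TrainingArguments` immutable once `__post_init__` has finished: assigning to a public field afterwards raises `FrozenInstanceError`, so the example scripts and tests now build a modified copy with `dataclasses.replace` instead of mutating the arguments in place. A minimal sketch of the new pattern (field values are illustrative only):

import dataclasses

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", learning_rate=5e-5)
# args.learning_rate = 1e-4  # would now raise dataclasses.FrozenInstanceError
# dataclasses.replace builds a new instance, which runs __post_init__ again and is frozen again
args = dataclasses.replace(args, learning_rate=1e-4)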
@@ -163,6 +163,15 @@ class CustomTrainingArguments(TrainingArguments):
         default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
     )

+    def __post_init__(self):
+        # Compute absolute learning rate while args are mutable
+        super().__post_init__()
+        if self.base_learning_rate is not None:
+            total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
+            delattr(self, "_frozen")
+            self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
+            setattr(self, "_frozen", True)
+

 def collate_fn(examples):
     pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -353,13 +362,6 @@ def main():
         # Set the validation transforms
         ds["validation"].set_transform(preprocess_images)

-    # Compute absolute learning rate
-    total_train_batch_size = (
-        training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
-    )
-    if training_args.base_learning_rate is not None:
-        training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
-
     # Initialize our trainer
     trainer = Trainer(
         model=model,
......
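Subclasses that derive one argument from another, as `CustomTrainingArguments` does above, have to write the derived field after `super().__post_init__()` has already frozen the instance, hence the `delattr`/`setattr` dance around the private `_frozen` flag. A condensed sketch of that pattern (the scaling rule mirrors the `base_learning_rate` field above; the class name and scaling are illustrative):

from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class ScaledLRArguments(TrainingArguments):
    base_learning_rate: float = field(default=1e-3)

    def __post_init__(self):
        super().__post_init__()  # sets self._frozen = True at the end
        delattr(self, "_frozen")  # temporarily lift the freeze
        self.learning_rate = self.base_learning_rate * self.train_batch_size / 256
        setattr(self, "_frozen", True)  # restore immutability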
@@ -18,6 +18,7 @@ Fine-tuning the library models for sequence to sequence.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

+import dataclasses
 import logging
 import os
 import sys
@@ -674,14 +675,10 @@ def main():
         return result

     # Override the decoding parameters of Seq2SeqTrainer
-    training_args.generation_max_length = (
-        training_args.generation_max_length
-        if training_args.generation_max_length is not None
-        else data_args.val_max_target_length
-    )
-    training_args.generation_num_beams = (
-        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
-    )
+    if training_args.generation_max_length is None:
+        training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
+    if training_args.generation_num_beams is None:
+        training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)

     # Initialize our Trainer
     trainer = Seq2SeqTrainer(
......
@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=fill-mask
 """
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.

+import dataclasses
 import json
 import logging
 import math
@@ -366,7 +367,7 @@ def main():
     # If we have ref files, need to avoid it removed by trainer
     has_ref = data_args.train_ref_file or data_args.validation_ref_file
     if has_ref:
-        training_args.remove_unused_columns = False
+        training_args = dataclasses.replace(training_args, remove_unused_columns=False)

     # Data collator
     # This one will take care of randomly masking the tokens.
......
@@ -259,7 +259,6 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."

     if training_args.output_dir is not None:
-        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
     # endregion
@@ -267,8 +266,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = training_args.output_dir / CONFIG_NAME
-        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
+        config_path = Path(training_args.output_dir) / CONFIG_NAME
+        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.info(
......
@@ -265,7 +265,6 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."

     if training_args.output_dir is not None:
-        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)

     if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -277,8 +276,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = training_args.output_dir / CONFIG_NAME
-        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
+        config_path = Path(training_args.output_dir) / CONFIG_NAME
+        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.warning(
......
@@ -18,7 +18,7 @@ import json
 import math
 import os
 import warnings
-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
@@ -1687,6 +1687,16 @@ class TrainingArguments:
                 mixed_precision_dtype = "bf16"
             os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype

+        # Finally set the `TrainingArguments` to be immutable
+        self._frozen = True
+
+    def __setattr__(self, name, value):
+        # Once fully through the `__post_init__`, `TrainingArguments` are immutable
+        if not name.startswith("_") and getattr(self, "_frozen", False):
+            raise FrozenInstanceError(f"cannot assign to field {name}")
+        else:
+            super().__setattr__(name, value)
+
     def __str__(self):
         self_as_dict = asdict(self)
......
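For reference, a quick sketch of the behaviour introduced by the `_frozen` flag and the `__setattr__` override above: public fields reject assignment once `__post_init__` has run, while underscore-prefixed names stay writable (the `output_dir` value and `_scratch` attribute are illustrative):

from dataclasses import FrozenInstanceError

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")
try:
    args.seed = 314
except FrozenInstanceError as err:
    print(err)  # cannot assign to field seed
args._scratch = True  # names starting with "_" are still assignable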
@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
     b: float = 0.0

     def __post_init__(self):
-        super().__post_init__()
         # save resources not dealing with reporting (also avoids the warning when it's not set)
         self.report_to = []
+        super().__post_init__()


 class RepeatDataset:
@@ -529,7 +529,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.check_trained_model(trainer.model)

         # Re-training should restart from scratch, thus lead the same results and new seed should be used.
-        trainer.args.seed = 314
+        args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
+        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import dataclasses
 from typing import Dict

 import numpy as np
@@ -205,7 +206,14 @@ if __name__ == "__main__":
             logger.error(p.metrics)
             exit(1)

-    trainer.args.eval_accumulation_steps = 2
+    training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
+    trainer = Trainer(
+        model=DummyModel(),
+        args=training_args,
+        data_collator=DummyDataCollator(),
+        eval_dataset=dataset,
+        compute_metrics=compute_metrics,
+    )
     metrics = trainer.evaluate()
     logger.info(metrics)
@@ -219,15 +227,22 @@ if __name__ == "__main__":
             logger.error(p.metrics)
             exit(1)

-    trainer.args.eval_accumulation_steps = None
+    training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
+    trainer = Trainer(
+        model=DummyModel(),
+        args=training_args,
+        data_collator=DummyDataCollator(),
+        eval_dataset=dataset,
+        compute_metrics=compute_metrics,
+    )

     # Check that `dispatch_batches=False` will work on a finite iterable dataset
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
     model = RegressionModel()
-    training_args.per_device_train_batch_size = 1
-    training_args.max_steps = 1
-    training_args.dispatch_batches = False
+    training_args = dataclasses.replace(
+        training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
+    )
     trainer = Trainer(model, training_args, train_dataset=train_dataset)
     trainer.train()
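Because `dataclasses.replace` returns a new `TrainingArguments` object, the tests above also rebuild the `Trainer` with the updated arguments instead of touching `trainer.args` directly.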