Unverified Commit fdccf82e authored by Sylvain Gugger, committed by GitHub

Remove config assumption in Trainer (#7464)

* Remove config assumption in Trainer

* Initialize for eval
parent cc4eff80
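
The commit replaces duck-typed reads of `self.model.config` with explicit `isinstance(model, PreTrainedModel)` guards, so the Trainer no longer assumes every model carries a `config` attribute and can also work with a plain `torch.nn.Module`. A minimal sketch of the guard pattern used throughout the diff (the `TinyRegressor` model and variable names are illustrative, not part of the change):

    import torch.nn as nn
    from transformers import PreTrainedModel

    class TinyRegressor(nn.Module):
        """Hypothetical model with no `config` attribute."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 1)

        def forward(self, x):
            return self.linear(x)

    model = TinyRegressor()

    # Only touch `model.config` when the model is a PreTrainedModel,
    # otherwise fall back to a safe default.
    if isinstance(model, PreTrainedModel):
        gradient_checkpointing = getattr(model.config, "gradient_checkpointing", False)
    else:
        gradient_checkpointing = False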
@@ -282,7 +282,7 @@ class Trainer:
         # Create output directory if needed
         if self.is_world_process_zero():
             os.makedirs(self.args.output_dir, exist_ok=True)
-        if is_torch_tpu_available():
+        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
             # Set an xla_device flag on the model's config.
             # We'll find a more elegant way and not need to do this in the future.
             self.model.config.xla_device = True
@@ -490,11 +490,9 @@ class Trainer:
             logger.info(
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
             )
-            try:
-                combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
-            except AttributeError:
-                # in case the model has no config
-                combined_dict = {**self.args.to_sanitized_dict()}
+            combined_dict = {**self.args.to_sanitized_dict()}
+            if isinstance(self.model, PreTrainedModel):
+                combined_dict = {**self.model.config.to_dict(), **combined_dict}
             wandb.init(
                 project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
             )
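
Note that the rewritten merge keeps the old precedence: training arguments still win over model config entries on duplicate keys, because later entries in a `{**a, **b}` merge overwrite earlier ones. A small illustration with made-up dictionaries standing in for `args.to_sanitized_dict()` and `model.config.to_dict()` (the colliding key is invented for the example):

    args_dict = {"learning_rate": 5e-5, "num_beams": 1}      # stand-in for args.to_sanitized_dict()
    config_dict = {"hidden_size": 768, "num_beams": 4}       # stand-in for model.config.to_dict()

    combined_dict = {**args_dict}
    combined_dict = {**config_dict, **combined_dict}

    print(combined_dict)
    # {'hidden_size': 768, 'num_beams': 1, 'learning_rate': 5e-05}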
@@ -533,6 +531,7 @@ class Trainer:
             if experiment is not None:
                 experiment._set_model_graph(self.model, framework="transformers")
                 experiment._log_parameters(self.args, prefix="args/", framework="transformers")
-                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
+                if isinstance(self.model, PreTrainedModel):
+                    experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

     def num_examples(self, dataloader: DataLoader) -> int:
@@ -679,7 +678,11 @@ class Trainer:
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
+                find_unused_parameters=(
+                    not getattr(model.config, "gradient_checkpointing", False)
+                    if isinstance(model, PreTrainedModel)
+                    else True
+                ),
             )
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
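
The new conditional expression folds the config lookup and the no-config case into a single value. Rewritten as a standalone helper for readability (the function name is invented for illustration; the logic mirrors the diff):

    from transformers import PreTrainedModel

    def ddp_find_unused_parameters(model) -> bool:
        # find_unused_parameters conflicts with gradient checkpointing (see the
        # linked comment), so it is only turned off when a PreTrainedModel has
        # gradient_checkpointing enabled in its config.
        if isinstance(model, PreTrainedModel):
            return not getattr(model.config, "gradient_checkpointing", False)
        # Models without a config keep the conservative default.
        return True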
@@ -707,15 +710,14 @@ class Trainer:
         self.global_step = 0
         self.epoch = 0
         self.total_flos = 0
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
         # Check if continuing training from a checkpoint
         if model_path is not None:
             # set global_step to global_step of last saved checkpoint from model path
             try:
                 self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
-                self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
                 epochs_trained = self.global_step // num_update_steps_per_epoch
                 steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
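
For context, the surviving resume logic recovers the global step by parsing it out of the checkpoint directory name rather than reading anything off the model. A tiny, self-contained illustration (the path is made up):

    import os

    model_path = os.path.join("output", "checkpoint-500")
    global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
    print(global_step)  # 500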
@@ -723,14 +725,13 @@ class Trainer:
                 logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                 logger.info("  Continuing training from epoch %d", epochs_trained)
                 logger.info("  Continuing training from global step %d", self.global_step)
-                logger.info("  Continuing training from %d non-embedding floating-point operations", self.total_flos)
                 logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
             except ValueError:
                 self.global_step = 0
-                self.total_flos = 0
                 logger.info("  Starting fine-tuning.")

         tr_loss = torch.tensor(0.0).to(self.args.device)
+        self.total_flos = self.state.total_flos
         logging_loss_scalar = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
@@ -1029,7 +1030,7 @@ class Trainer:
             else:
                 total_flos = self.total_flos
             if total_flos > 0:
-                logs["total_flos"] = self.total_flos
+                logs["total_flos"] = total_flos
         if self.global_step is None:
             # when logging evaluation metrics without training
             self.global_step = 0
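
The practical effect of this one-line fix shows up in distributed runs: `total_flos` holds the value aggregated over processes, while `self.total_flos` is still the local counter, so logging the former reports the true total. A toy illustration with made-up numbers (no actual distributed setup):

    per_process_flos = [1_000, 1_200, 900, 1_100]   # hypothetical self.total_flos on each rank
    total_flos = sum(per_process_flos)              # what the aggregated value would be

    logs = {}
    if total_flos > 0:
        logs["total_flos"] = total_flos             # 4200, not just this rank's 1000

    print(logs)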
@@ -1245,11 +1246,9 @@ class Trainer:
         # Storing the number of floating-point operations that went into the model
         if self.total_flos is not None:
             if self.args.local_rank != -1:
-                total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
+                self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
             else:
-                total_flos = self.total_flos
-            if total_flos > 0:
-                self.model.config.total_flos = total_flos
+                self.state.total_flos = self.total_flos

     def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
         ordering_and_checkpoint_path = []
@@ -1363,13 +1362,6 @@ class Trainer:
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )

-        assert not getattr(
-            self.model.config, "output_attentions", False
-        ), "The prediction loop does not work with `output_attentions=True`."
-        assert not getattr(
-            self.model.config, "output_hidden_states", False
-        ), "The prediction loop does not work with `output_hidden_states=True`."
-
         model = self.model
         # multi-gpu eval
         if self.args.n_gpu > 1:
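
The removed assertions were themselves a config assumption: they read `self.model.config`, which a bare `nn.Module` does not have, so even evaluating the guard would fail before the assertion could fire. A quick demonstration of the failure mode the removal avoids:

    import torch.nn as nn

    model = nn.Linear(4, 2)  # any model without a `config` attribute
    try:
        getattr(model.config, "output_attentions", False)
    except AttributeError as err:
        print(err)  # 'Linear' object has no attribute 'config'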
@@ -224,6 +224,7 @@ class TrainerState:
     A class containing the `Trainer` fields that will be saved along the model and optimizer.
     """

+    total_flos: int = 0
     best_metric: Optional[float] = None
     best_model_checkpoint: Optional[str] = None
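
Moving the flos counter onto `TrainerState` means it is serialized with the rest of the trainer state instead of being written into the model config. A minimal, self-contained sketch of that idea, using a simplified stand-in dataclass rather than the actual transformers implementation:

    import json
    from dataclasses import asdict, dataclass
    from typing import Optional

    @dataclass
    class MiniTrainerState:
        total_flos: int = 0
        best_metric: Optional[float] = None
        best_model_checkpoint: Optional[str] = None

        def save_to_json(self, json_path: str):
            # The counter travels with the saved state, not with the model config.
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(asdict(self), f, indent=2)

    state = MiniTrainerState(total_flos=4_200)
    state.save_to_json("trainer_state.json")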