"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "edb6b369d76ff0a132f9e18c69aa06aebfe87fe0"
Unverified commit fdccf82e, authored by Sylvain Gugger and committed by GitHub
Browse files

Remove config assumption in Trainer (#7464)

* Remove config assumption in Trainer

* Initialize for eval
parent cc4eff80
...@@ -282,7 +282,7 @@ class Trainer: ...@@ -282,7 +282,7 @@ class Trainer:
# Create output directory if needed # Create output directory if needed
if self.is_world_process_zero(): if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True) os.makedirs(self.args.output_dir, exist_ok=True)
if is_torch_tpu_available(): if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
# Set an xla_device flag on the model's config. # Set an xla_device flag on the model's config.
# We'll find a more elegant and not need to do this in the future. # We'll find a more elegant and not need to do this in the future.
self.model.config.xla_device = True self.model.config.xla_device = True
...@@ -490,11 +490,9 @@ class Trainer: ...@@ -490,11 +490,9 @@ class Trainer:
logger.info( logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
) )
try: combined_dict = {**self.args.to_sanitized_dict()}
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} if isinstance(self.model, PreTrainedModel):
except AttributeError: combined_dict = {**self.model.config.to_dict(), **combined_dict}
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}
wandb.init( wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
) )
...@@ -533,7 +531,8 @@ class Trainer: ...@@ -533,7 +531,8 @@ class Trainer:
if experiment is not None: if experiment is not None:
experiment._set_model_graph(self.model, framework="transformers") experiment._set_model_graph(self.model, framework="transformers")
experiment._log_parameters(self.args, prefix="args/", framework="transformers") experiment._log_parameters(self.args, prefix="args/", framework="transformers")
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers") if isinstance(self.model, PreTrainedModel):
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
def num_examples(self, dataloader: DataLoader) -> int: def num_examples(self, dataloader: DataLoader) -> int:
""" """
...@@ -679,7 +678,11 @@ class Trainer: ...@@ -679,7 +678,11 @@ class Trainer:
model, model,
device_ids=[self.args.local_rank], device_ids=[self.args.local_rank],
output_device=self.args.local_rank, output_device=self.args.local_rank,
find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False), find_unused_parameters=(
not getattr(model.config, "gradient_checkpointing", False)
if isinstance(model, PreTrainedModel)
else True
),
) )
# find_unused_parameters breaks checkpointing as per # find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
...@@ -707,15 +710,14 @@ class Trainer: ...@@ -707,15 +710,14 @@ class Trainer:
self.global_step = 0 self.global_step = 0
self.epoch = 0 self.epoch = 0
self.total_flos = 0
epochs_trained = 0 epochs_trained = 0
steps_trained_in_current_epoch = 0 steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint # Check if continuing training from a checkpoint
if model_path is not None: if model_path is not None:
# set global_step to global_step of last saved checkpoint from model path # set global_step to global_step of last saved checkpoint from model path
try: try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
epochs_trained = self.global_step // num_update_steps_per_epoch epochs_trained = self.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch) steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
...@@ -723,14 +725,13 @@ class Trainer: ...@@ -723,14 +725,13 @@ class Trainer:
logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step) logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError: except ValueError:
self.global_step = 0 self.global_step = 0
self.total_flos = 0
logger.info(" Starting fine-tuning.") logger.info(" Starting fine-tuning.")
tr_loss = torch.tensor(0.0).to(self.args.device) tr_loss = torch.tensor(0.0).to(self.args.device)
self.total_flos = self.state.total_flos
logging_loss_scalar = 0.0 logging_loss_scalar = 0.0
model.zero_grad() model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
...@@ -1029,7 +1030,7 @@ class Trainer: ...@@ -1029,7 +1030,7 @@ class Trainer:
else: else:
total_flos = self.total_flos total_flos = self.total_flos
if total_flos > 0: if total_flos > 0:
logs["total_flos"] = self.total_flos logs["total_flos"] = total_flos
if self.global_step is None: if self.global_step is None:
# when logging evaluation metrics without training # when logging evaluation metrics without training
self.global_step = 0 self.global_step = 0
...@@ -1245,11 +1246,9 @@ class Trainer: ...@@ -1245,11 +1246,9 @@ class Trainer:
# Storing the number of floating-point operations that went into the model # Storing the number of floating-point operations that went into the model
if self.total_flos is not None: if self.total_flos is not None:
if self.args.local_rank != -1: if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else: else:
total_flos = self.total_flos self.state.total_flos = self.total_flos
if total_flos > 0:
self.model.config.total_flos = total_flos
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = [] ordering_and_checkpoint_path = []
...@@ -1363,13 +1362,6 @@ class Trainer: ...@@ -1363,13 +1362,6 @@ class Trainer:
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
) )
assert not getattr(
self.model.config, "output_attentions", False
), "The prediction loop does not work with `output_attentions=True`."
assert not getattr(
self.model.config, "output_hidden_states", False
), "The prediction loop does not work with `output_hidden_states=True`."
model = self.model model = self.model
# multi-gpu eval # multi-gpu eval
if self.args.n_gpu > 1: if self.args.n_gpu > 1:
......
...@@ -224,6 +224,7 @@ class TrainerState: ...@@ -224,6 +224,7 @@ class TrainerState:
A class containing the `Trainer` fields that will be saved along the model and optimizer. A class containing the `Trainer` fields that will be saved along the model and optimizer.
""" """
total_flos: int = 0
best_metric: Optional[float] = None best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None best_model_checkpoint: Optional[str] = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment