Unverified Commit 95ffbe16 authored by Stas Bekman, committed by GitHub

[Trainer] fix the placement on device with fp16_full_eval (#11322)

* fix the placement on device with fp16_full_eval

* deepspeed never goes on device
parent 3981ce3d
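For context, a minimal sketch of the setup this commit targets; the model, dataset, and output path below are illustrative assumptions, not part of the change. The later hunks in the diff also cover the case where `.train()` is still called in such a run even though `do_train` was never set.

```python
from transformers import Trainer, TrainingArguments

# Eval-only run with fp16_full_eval: do_train is never set, so the Trainer skips
# placing the model on the device in __init__ and lets the evaluation loop cast
# it to fp16 first, moving it to the device only afterwards.
args = TrainingArguments(
    output_dir="out",              # illustrative path
    fp16_full_eval=True,
    per_device_eval_batch_size=8,
)

# `model` and `eval_ds` are assumed to exist (any PreTrainedModel / eval dataset).
trainer = Trainer(model=model, args=args, eval_dataset=eval_ds)
metrics = trainer.evaluate()       # model.half() happens first, then .to(args.device)
```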
@@ -336,7 +336,7 @@ class Trainer:
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or (args.deepspeed and args.do_train)
or args.deepspeed
or (args.fp16_full_eval and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
):
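A hedged restatement of the updated condition (plain Python, not Trainer code; the string values for sharded_ddp stand in for the ShardedDDPOption enum): whenever it is true, the Trainer leaves device placement to later code, since DeepSpeed places the model during its own initialization and fp16_full_eval wants to halve the weights before moving them.

```python
def defers_device_placement(args, is_model_parallel, sharded_ddp) -> bool:
    # Any DeepSpeed run (train or eval), model parallelism, an eval-only
    # fp16_full_eval run, or ZeRO-DP stage 2/3 sharding skips eager placement.
    return bool(
        is_model_parallel
        or args.deepspeed
        or (args.fp16_full_eval and not args.do_train)
        or sharded_ddp in ("zero_dp_2", "zero_dp_3")
    )
```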
@@ -954,8 +954,15 @@ class Trainer:
# memory metrics - must set up as early as possible
self._memory_tracker.start()
args = self.args
self.is_in_train = True
# do_train is not a reliable argument, as it might not be set and .train() still called, so
# the following is a workaround:
if args.fp16_full_eval and not args.do_train:
self.model = self.model.to(args.device)
if "model_path" in kwargs:
resume_from_checkpoint = kwargs.pop("model_path")
warnings.warn(
@@ -972,7 +979,7 @@ class Trainer:
model_reloaded = False
if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init.
set_seed(self.args.seed)
set_seed(args.seed)
self.model = self.call_model_init(trial)
model_reloaded = True
# Reinitializes optimizer and scheduler
@@ -980,9 +987,9 @@ class Trainer:
# Load potential model checkpoint
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
resume_from_checkpoint = get_last_checkpoint(args.output_dir)
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
if resume_from_checkpoint is not None:
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
@@ -1003,7 +1010,7 @@ class Trainer:
# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
if self.place_model_on_device:
self.model = self.model.to(self.args.device)
self.model = self.model.to(args.device)
self.model_wrapped = self.model
# Keeping track whether we can call len() on the dataset or not
@@ -1017,24 +1024,24 @@ class Trainer:
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
if train_dataset_is_sized:
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
else:
max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(self.args.num_train_epochs)
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
else:
# see __init__. max_steps is set when the dataset has no __len__
max_steps = self.args.max_steps
num_train_epochs = int(self.args.num_train_epochs)
max_steps = args.max_steps
num_train_epochs = int(args.num_train_epochs)
num_update_steps_per_epoch = max_steps
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if self.args.deepspeed:
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
)
@@ -1068,24 +1075,22 @@ class Trainer:
# Train!
if is_torch_tpu_available():
world_size = xm.xrt_world_size()
elif self.args.local_rank != -1:
elif args.local_rank != -1:
world_size = dist.get_world_size()
else:
world_size = 1
total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
num_examples = (
self.num_examples(train_dataloader)
if train_dataset_is_sized
else total_train_batch_size * self.args.max_steps
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
self.state.epoch = 0
@@ -1099,16 +1104,16 @@ class Trainer:
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not self.args.ignore_data_skip:
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not self.args.ignore_data_skip:
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
"batches in the first epoch."
@@ -1129,17 +1134,17 @@ class Trainer:
self.state.is_world_process_zero = self.is_world_process_zero()
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(self.args.device)
tr_loss = torch.tensor(0.0).to(args.device)
# _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
self._total_flos = self.state.total_flos
model.zero_grad()
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not self.args.ignore_data_skip:
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
# We just need to begin an iteration to create the randomization of the sampler.
for _ in train_dataloader:
@@ -1152,23 +1157,19 @@ class Trainer:
train_dataloader.dataset.set_epoch(epoch)
if is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
self.args.device
)
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
epoch_iterator = parallel_loader
else:
epoch_iterator = train_dataloader
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
if args.past_index >= 0:
self._past = None
steps_in_epoch = (
len(epoch_iterator)
if train_dataset_is_sized
else self.args.max_steps * self.args.gradient_accumulation_steps
len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
for step, inputs in enumerate(epoch_iterator):
@@ -1177,13 +1178,13 @@ class Trainer:
steps_trained_in_current_epoch -= 1
continue
if step % self.args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
if (
((step + 1) % self.args.gradient_accumulation_steps != 0)
and self.args.local_rank != -1
and self.args._no_sync_in_gradient_accumulation
((step + 1) % args.gradient_accumulation_steps != 0)
and args.local_rank != -1
and args._no_sync_in_gradient_accumulation
):
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
@@ -1196,13 +1197,13 @@ class Trainer:
if self.deepspeed:
self.deepspeed.step()
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
if (step + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= self.args.gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
# Gradient clipping
if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
# deepspeed does its own clipping
if self.use_amp:
@@ -1211,15 +1212,15 @@ class Trainer:
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
self.optimizer.clip_grad_norm(args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(self.args.max_grad_norm)
model.clip_grad_norm_(args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
self.args.max_grad_norm,
args.max_grad_norm,
)
# Optimizer step
@@ -1243,17 +1244,17 @@ class Trainer:
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
if self.args.tpu_metrics_debug or self.args.debug:
if args.tpu_metrics_debug or args.debug:
if is_torch_tpu_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
@@ -1265,16 +1266,16 @@ class Trainer:
if self.control.should_training_stop:
break
if self.args.past_index and hasattr(self, "_past"):
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_tpu_available():
xm.rendezvous("load_best_model_at_end")
elif self.args.local_rank != -1:
elif args.local_rank != -1:
dist.barrier()
logger.info(
@@ -1283,7 +1284,7 @@ class Trainer:
if isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
if self.place_model_on_device:
self.model = self.model.to(self.args.device)
self.model = self.model.to(args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
@@ -1299,7 +1300,7 @@ class Trainer:
metrics["total_flos"] = self.state.total_flos
self.log(metrics)
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
@@ -1952,7 +1953,7 @@ class Trainer:
model = self._wrap_model(self.model, training=False)
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
# ``train`` is running, half it first and then put on device
# ``train`` is running, halve it first and then put on device
if not self.is_in_train and self.args.fp16_full_eval:
model = model.half().to(self.args.device)
@@ -2288,7 +2289,7 @@ class Trainer:
model = self._wrap_model(self.model, training=False)
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
# ``train`` is running, half it first and then put on device
# ``train`` is running, halve it first and then put on device
if not self.is_in_train and self.args.fp16_full_eval:
model = model.half().to(self.args.device)
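The ordering in `model.half().to(self.args.device)` above is the point of deferring placement: the weights are cast to fp16 on the host and only then moved, so the device never holds the fp32 copy. A rough standalone illustration, with an arbitrary checkpoint name:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # any fp32 checkpoint

# Deferred placement: cast on the CPU first, then move; only the fp16 weights
# (roughly half the fp32 footprint) are ever allocated on the device.
model = model.half().to("cuda")

# The order eager placement would produce: move the fp32 weights first and cast
# afterwards, briefly holding the full fp32 copy in device memory.
# model = model.to("cuda").half()
```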