Unverified commit 95ffbe16 authored by Stas Bekman, committed by GitHub

[Trainer] fix the placement on device with fp16_full_eval (#11322)

* fix the placement on device with fp16_full_eval

* deepspeed never goes on device
parent 3981ce3d
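
For context, a minimal sketch of the situation the workaround targets (illustrative only; the model name, dummy dataset, and argument values are assumptions, not part of this commit): fp16_full_eval is requested, do_train is left unset, and .train() is still called, so the model has to be moved to the device inside train() itself.

    import torch
    from torch.utils.data import Dataset
    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    class TinyDataset(Dataset):
        # Dummy dataset so train() has something to iterate over (assumption, for illustration only).
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return {"input_ids": torch.tensor([101, 2023, 102]), "labels": torch.tensor(0)}

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    # fp16_full_eval=True with do_train left at its default (False): Trainer.__init__ skips the
    # usual model.to(device), so before this fix a later .train() call ran with the model
    # still on CPU.
    args = TrainingArguments(output_dir="out", fp16_full_eval=True)

    trainer = Trainer(model=model, args=args, train_dataset=TinyDataset())
    trainer.train()  # the workaround added in this commit moves the model to args.device here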
@@ -336,7 +336,7 @@ class Trainer:
         self.place_model_on_device = args.place_model_on_device
         if (
             self.is_model_parallel
-            or (args.deepspeed and args.do_train)
+            or args.deepspeed
             or (args.fp16_full_eval and not args.do_train)
             or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
         ):
@@ -954,8 +954,15 @@ class Trainer:
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
 
+        args = self.args
+
         self.is_in_train = True
 
+        # do_train is not a reliable argument, as it might not be set and .train() still called, so
+        # the following is a workaround:
+        if args.fp16_full_eval and not args.do_train:
+            self.model = self.model.to(args.device)
+
         if "model_path" in kwargs:
             resume_from_checkpoint = kwargs.pop("model_path")
             warnings.warn(
@@ -972,7 +979,7 @@ class Trainer:
         model_reloaded = False
         if self.model_init is not None:
             # Seed must be set before instantiating the model when using model_init.
-            set_seed(self.args.seed)
+            set_seed(args.seed)
             self.model = self.call_model_init(trial)
             model_reloaded = True
             # Reinitializes optimizer and scheduler
@@ -980,9 +987,9 @@ class Trainer:
 
         # Load potential model checkpoint
         if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
-            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
+            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
             if resume_from_checkpoint is None:
-                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
+                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
         if resume_from_checkpoint is not None:
             if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
@@ -1003,7 +1010,7 @@ class Trainer:
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
             if self.place_model_on_device:
-                self.model = self.model.to(self.args.device)
+                self.model = self.model.to(args.device)
             self.model_wrapped = self.model
 
         # Keeping track whether we can can len() on the dataset or not
@@ -1017,24 +1024,24 @@ class Trainer:
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
         if train_dataset_is_sized:
-            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
+            num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
             num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
-            if self.args.max_steps > 0:
-                max_steps = self.args.max_steps
-                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
-                    self.args.max_steps % num_update_steps_per_epoch > 0
+            if args.max_steps > 0:
+                max_steps = args.max_steps
+                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+                    args.max_steps % num_update_steps_per_epoch > 0
                 )
             else:
-                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
-                num_train_epochs = math.ceil(self.args.num_train_epochs)
+                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+                num_train_epochs = math.ceil(args.num_train_epochs)
         else:
             # see __init__. max_steps is set when the dataset has no __len__
-            max_steps = self.args.max_steps
-            num_train_epochs = int(self.args.num_train_epochs)
+            max_steps = args.max_steps
+            num_train_epochs = int(args.num_train_epochs)
             num_update_steps_per_epoch = max_steps
 
         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
-        if self.args.deepspeed:
+        if args.deepspeed:
             deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                 self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
             )
@@ -1068,24 +1075,22 @@ class Trainer:
         # Train!
         if is_torch_tpu_available():
             world_size = xm.xrt_world_size()
-        elif self.args.local_rank != -1:
+        elif args.local_rank != -1:
             world_size = dist.get_world_size()
         else:
             world_size = 1
 
-        total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
+        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
         num_examples = (
-            self.num_examples(train_dataloader)
-            if train_dataset_is_sized
-            else total_train_batch_size * self.args.max_steps
+            self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
         )
 
         logger.info("***** Running training *****")
         logger.info(f"  Num examples = {num_examples}")
         logger.info(f"  Num Epochs = {num_train_epochs}")
-        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")
 
         self.state.epoch = 0
@@ -1099,16 +1104,16 @@ class Trainer:
         ):
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
-            if not self.args.ignore_data_skip:
+            if not args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
             else:
                 steps_trained_in_current_epoch = 0
 
             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
             logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
-            if not self.args.ignore_data_skip:
+            if not args.ignore_data_skip:
                 logger.info(
                     f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                     "batches in the first epoch."
@@ -1129,17 +1134,17 @@ class Trainer:
         self.state.is_world_process_zero = self.is_world_process_zero()
 
         # tr_loss is a tensor to avoid synchronization of TPUs through .item()
-        tr_loss = torch.tensor(0.0).to(self.args.device)
+        tr_loss = torch.tensor(0.0).to(args.device)
         # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         self._total_flos = self.state.total_flos
         model.zero_grad()
 
-        self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
+        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
-        if not self.args.ignore_data_skip:
+        if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
                 # We just need to begin an iteration to create the randomization of the sampler.
                 for _ in train_dataloader:
@@ -1152,23 +1157,19 @@ class Trainer:
                 train_dataloader.dataset.set_epoch(epoch)
 
             if is_torch_tpu_available():
-                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
-                    self.args.device
-                )
+                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                 epoch_iterator = parallel_loader
             else:
                 epoch_iterator = train_dataloader
 
             # Reset the past mems state at the beginning of each epoch if necessary.
-            if self.args.past_index >= 0:
+            if args.past_index >= 0:
                 self._past = None
 
             steps_in_epoch = (
-                len(epoch_iterator)
-                if train_dataset_is_sized
-                else self.args.max_steps * self.args.gradient_accumulation_steps
+                len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
             )
-            self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
+            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
 
             for step, inputs in enumerate(epoch_iterator):
@@ -1177,13 +1178,13 @@ class Trainer:
                     steps_trained_in_current_epoch -= 1
                     continue
 
-                if step % self.args.gradient_accumulation_steps == 0:
-                    self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
+                if step % args.gradient_accumulation_steps == 0:
+                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
 
                 if (
-                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
-                    and self.args.local_rank != -1
-                    and self.args._no_sync_in_gradient_accumulation
+                    ((step + 1) % args.gradient_accumulation_steps != 0)
+                    and args.local_rank != -1
+                    and args._no_sync_in_gradient_accumulation
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
@@ -1196,13 +1197,13 @@ class Trainer:
                 if self.deepspeed:
                     self.deepspeed.step()
 
-                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
+                if (step + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
-                    steps_in_epoch <= self.args.gradient_accumulation_steps
+                    steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
                     # Gradient clipping
-                    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
+                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                         # deepspeed does its own clipping
 
                         if self.use_amp:
@@ -1211,15 +1212,15 @@ class Trainer:
                         if hasattr(self.optimizer, "clip_grad_norm"):
                             # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
-                            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                         elif hasattr(model, "clip_grad_norm_"):
                             # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
-                            model.clip_grad_norm_(self.args.max_grad_norm)
+                            model.clip_grad_norm_(args.max_grad_norm)
                         else:
                             # Revert to normal clipping otherwise, handling Apex or full precision
                             torch.nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
-                                self.args.max_grad_norm,
+                                args.max_grad_norm,
                             )
 
                     # Optimizer step
@@ -1243,17 +1244,17 @@ class Trainer:
                     model.zero_grad()
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
-                    self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
+                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
 
                     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break
 
-            self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
+            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
 
-            if self.args.tpu_metrics_debug or self.args.debug:
+            if args.tpu_metrics_debug or args.debug:
                 if is_torch_tpu_available():
                     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                     xm.master_print(met.metrics_report())
@@ -1265,16 +1266,16 @@ class Trainer:
             if self.control.should_training_stop:
                 break
 
-        if self.args.past_index and hasattr(self, "_past"):
+        if args.past_index and hasattr(self, "_past"):
             # Clean the state at the end of training
             delattr(self, "_past")
 
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
-        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
             # Wait for everyone to get here so we are sur the model has been saved by process 0.
             if is_torch_tpu_available():
                 xm.rendezvous("load_best_model_at_end")
-            elif self.args.local_rank != -1:
+            elif args.local_rank != -1:
                 dist.barrier()
 
             logger.info(
@@ -1283,7 +1284,7 @@ class Trainer:
             if isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
                 if self.place_model_on_device:
-                    self.model = self.model.to(self.args.device)
+                    self.model = self.model.to(args.device)
             else:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                 self.model.load_state_dict(state_dict)
@@ -1299,7 +1300,7 @@ class Trainer:
         metrics["total_flos"] = self.state.total_flos
         self.log(metrics)
 
-        self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
+        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
 
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()
@@ -1952,7 +1953,7 @@ class Trainer:
         model = self._wrap_model(self.model, training=False)
 
         # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
-        # ``train`` is running, half it first and then put on device
+        # ``train`` is running, halve it first and then put on device
         if not self.is_in_train and self.args.fp16_full_eval:
             model = model.half().to(self.args.device)
@@ -2288,7 +2289,7 @@ class Trainer:
         model = self._wrap_model(self.model, training=False)
 
         # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
-        # ``train`` is running, half it first and then put on device
+        # ``train`` is running, halve it first and then put on device
         if not self.is_in_train and self.args.fp16_full_eval:
             model = model.half().to(self.args.device)