Unverified Commit 02d09c8f authored by Jin Young (Daniel) Sohn's avatar Jin Young (Daniel) Sohn Committed by GitHub
Browse files

Only access loss tensor every logging_steps (#6802)



* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors as it's terrible for performance causing TPU<>CPU
communication at each step. On RoBERTa MLM for example, it reduces step
time by 30%, should be larger for smaller step time models/tasks.
* Train batch size was not correct in case a user uses the
`per_gpu_train_batch_size` flag
* Avg reduce loss across eval shards

* Fix style (#6803)

* t5 model should make decoder_attention_mask (#6800)

* [s2s] Test hub configs in self-scheduled CI (#6809)

* [s2s] round runtime in run_eval (#6798)

* Pegasus finetune script: add --adafactor (#6811)

* [bart] rename self-attention -> attention (#6708)

* [tests] fix typos in inputs (#6818)

* Fixed open in colab link (#6825)

* Add model card for singbert lite. Update widget for singbert and singbert-large. (#6827)

* BR_BERTo model card (#6793)

* clearly indicate shuffle=False (#6312)

* Clarify shuffle

* clarify shuffle
Co-authored-by: default avatarKevin Canwen Xu <canwenxu@126.com>

* [s2s README] Add more dataset download instructions (#6737)

* Style

* Patch logging issue

* Set default logging level to `WARNING` instead of `INFO`

* TF Flaubert w/ pre-norm (#6841)

* Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (#6644)

* add datacollator and dataset for next sentence prediction task

* bug fix (numbers of special tokens & truncate sequences)

* bug fix (+ dict inputs support for data collator)

* add padding for nsp data collator; renamed cached files to avoid conflict.

* add test for nsp data collator

* Style
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>

* Fix in Adafactor docstrings (#6845)

* Fix resuming training for Windows (#6847)

* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors as it's terrible for performance causing TPU<>CPU
communication at each step. On RoBERTa MLM for example, it reduces step
time by 30%, should be larger for smaller step time models/tasks.
* Train batch size was not correct in case a user uses the
`per_gpu_train_batch_size` flag
* Avg reduce loss across eval shards

* comments
Co-authored-by: default avatarSam Shleifer <sshleifer@gmail.com>
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>
Co-authored-by: default avatarThomas Ashish Cherian <6967017+PandaWhoCodes@users.noreply.github.com>
Co-authored-by: default avatarZane Lim <zyuanlim@gmail.com>
Co-authored-by: default avatarRodolfo De Nadai <rdenadai@gmail.com>
Co-authored-by: default avatarxujiaze13 <37360975+xujiaze13@users.noreply.github.com>
Co-authored-by: default avatarKevin Canwen Xu <canwenxu@126.com>
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Co-authored-by: default avatarHuang Lianzhe <hlz@pku.edu.cn>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c48546c7
...@@ -658,8 +658,8 @@ class Trainer: ...@@ -658,8 +658,8 @@ class Trainer:
self.global_step = 0 self.global_step = 0
logger.info(" Starting fine-tuning.") logger.info(" Starting fine-tuning.")
tr_loss = 0.0 tr_loss = torch.tensor(0.0).to(self.args.device)
logging_loss = 0.0 logging_loss_scalar = 0.0
model.zero_grad() model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm) train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
...@@ -720,14 +720,15 @@ class Trainer: ...@@ -720,14 +720,15 @@ class Trainer:
self.global_step == 1 and self.args.logging_first_step self.global_step == 1 and self.args.logging_first_step
): ):
logs: Dict[str, float] = {} logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
# backward compatibility for pytorch schedulers # backward compatibility for pytorch schedulers
logs["learning_rate"] = ( logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0] self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4") if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0] else self.lr_scheduler.get_lr()[0]
) )
logging_loss = tr_loss logging_loss_scalar = tr_loss_scalar
self.log(logs) self.log(logs)
...@@ -773,8 +774,6 @@ class Trainer: ...@@ -773,8 +774,6 @@ class Trainer:
break break
epoch_pbar.close() epoch_pbar.close()
train_pbar.update(1) train_pbar.update(1)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
break
if self.args.tpu_metrics_debug or self.args.debug: if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available(): if is_torch_tpu_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
...@@ -784,6 +783,8 @@ class Trainer: ...@@ -784,6 +783,8 @@ class Trainer:
"You enabled PyTorch/XLA debug metrics but you don't have a TPU " "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected." "configured. Check your training configuration if this is unexpected."
) )
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
break
train_pbar.close() train_pbar.close()
if self.tb_writer: if self.tb_writer:
...@@ -793,7 +794,7 @@ class Trainer: ...@@ -793,7 +794,7 @@ class Trainer:
delattr(self, "_past") delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
return TrainOutput(self.global_step, tr_loss / self.global_step) return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
def hyperparameter_search( def hyperparameter_search(
self, self,
...@@ -973,7 +974,7 @@ class Trainer: ...@@ -973,7 +974,7 @@ class Trainer:
return inputs return inputs
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float: def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
""" """
Perform a training step on a batch of inputs. Perform a training step on a batch of inputs.
...@@ -989,7 +990,7 @@ class Trainer: ...@@ -989,7 +990,7 @@ class Trainer:
argument :obj:`labels`. Check your model's documentation for all accepted arguments. argument :obj:`labels`. Check your model's documentation for all accepted arguments.
Return: Return:
:obj:`float`: The training loss on this batch. :obj:`torch.Tensor`: The tensor with training loss on this batch.
""" """
if hasattr(self, "_training_step"): if hasattr(self, "_training_step"):
warnings.warn( warnings.warn(
...@@ -1027,7 +1028,7 @@ class Trainer: ...@@ -1027,7 +1028,7 @@ class Trainer:
else: else:
loss.backward() loss.backward()
return loss.item() return loss
def is_local_master(self) -> bool: def is_local_master(self) -> bool:
""" """
...@@ -1276,6 +1277,10 @@ class Trainer: ...@@ -1276,6 +1277,10 @@ class Trainer:
preds = xm.mesh_reduce("eval_preds", preds, torch.cat) preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
if label_ids is not None: if label_ids is not None:
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
if samples_count is not None:
samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())
# Finally, turn the aggregated tensors into numpy arrays. # Finally, turn the aggregated tensors into numpy arrays.
if preds is not None: if preds is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment