"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
...
@@ -80,11 +80,17 @@ class TFTrainer:
...
@@ -80,11 +80,17 @@ class TFTrainer:
self.train_dataset=train_dataset
self.train_dataset=train_dataset
self.eval_dataset=eval_dataset
self.eval_dataset=eval_dataset
self.compute_metrics=compute_metrics
self.compute_metrics=compute_metrics
self.prediction_loss_only=prediction_loss_only
self.optimizer,self.lr_scheduler=optimizers
self.optimizer,self.lr_scheduler=optimizers
self.gradient_accumulator=GradientAccumulator()
self.gradient_accumulator=GradientAccumulator()
self.global_step=0
self.global_step=0
self.epoch_logging=0
self.epoch_logging=0
if"prediction_loss_only"inkwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",