"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "18eec3a9847da4c879a3af8c5a57e9aaf70adf6d"
Unverified Commit 6312fed4 authored by Maria Janina Sarol's avatar Maria Janina Sarol Committed by GitHub
Browse files

Fix TFTrainer prediction output (#9662)

* Fix TFTrainer prediction output

* Update trainer_tf.py

* Fix TFTrainer prediction output

* Fix evaluation_loss update in TFTrainer

* Fix TFTrainer prediction output
parent 9152f160
...@@ -101,6 +101,7 @@ class TFTrainer: ...@@ -101,6 +101,7 @@ class TFTrainer:
self.gradient_accumulator = GradientAccumulator() self.gradient_accumulator = GradientAccumulator()
self.global_step = 0 self.global_step = 0
self.epoch_logging = 0 self.epoch_logging = 0
self.eval_loss = tf.keras.metrics.Sum()
if tb_writer is not None: if tb_writer is not None:
self.tb_writer = tb_writer self.tb_writer = tb_writer
...@@ -202,13 +203,8 @@ class TFTrainer: ...@@ -202,13 +203,8 @@ class TFTrainer:
if num_examples < 0: if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality") raise ValueError("The training dataset must have an asserted cardinality")
approx = math.floor if self.args.dataloader_drop_last else math.ceil steps = math.ceil(num_examples / self.args.eval_batch_size)
steps = approx(num_examples / self.args.eval_batch_size) ds = test_dataset.batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
ds = (
test_dataset.repeat()
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE)
)
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
...@@ -300,12 +296,14 @@ class TFTrainer: ...@@ -300,12 +296,14 @@ class TFTrainer:
) )
logger.info("***** Running %s *****", description) logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", num_examples) logger.info(" Num examples in dataset = %d", num_examples)
if description == "Evaluation":
logger.info(" Num examples in used in evaluation = %d", self.args.eval_batch_size * steps)
logger.info(" Batch size = %d", self.args.eval_batch_size) logger.info(" Batch size = %d", self.args.eval_batch_size)
label_ids: np.ndarray = None label_ids: np.ndarray = None
preds: np.ndarray = None preds: np.ndarray = None
self.eval_loss = tf.keras.metrics.Sum() self.eval_loss.reset_states()
# Reset the past mems state at the beginning of the evaluation if necessary. # Reset the past mems state at the beginning of the evaluation if necessary.
if self.args.past_index >= 0: if self.args.past_index >= 0:
...@@ -345,7 +343,7 @@ class TFTrainer: ...@@ -345,7 +343,7 @@ class TFTrainer:
else: else:
label_ids = np.append(label_ids, labels.numpy(), axis=0) label_ids = np.append(label_ids, labels.numpy(), axis=0)
if step == steps: if step == steps - 1:
break break
if self.compute_metrics is not None and preds is not None and label_ids is not None: if self.compute_metrics is not None and preds is not None and label_ids is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment