"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "faacd747299f6bf3304f215c8570cbce867c870f"
Unverified Commit ba702966 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix cardinality (#9505)

parent 33b74228
...@@ -135,7 +135,7 @@ class TFTrainer: ...@@ -135,7 +135,7 @@ class TFTrainer:
raise ValueError("Trainer: training requires a train_dataset.") raise ValueError("Trainer: training requires a train_dataset.")
self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
self.num_train_examples = self.train_dataset.cardinality(self.train_dataset).numpy() self.num_train_examples = self.train_dataset.cardinality().numpy()
if self.num_train_examples < 0: if self.num_train_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality") raise ValueError("The training dataset must have an asserted cardinality")
...@@ -167,7 +167,7 @@ class TFTrainer: ...@@ -167,7 +167,7 @@ class TFTrainer:
raise ValueError("Trainer: evaluation requires an eval_dataset.") raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
num_examples = eval_dataset.cardinality(eval_dataset).numpy() num_examples = eval_dataset.cardinality().numpy()
if num_examples < 0: if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality") raise ValueError("The training dataset must have an asserted cardinality")
...@@ -197,7 +197,7 @@ class TFTrainer: ...@@ -197,7 +197,7 @@ class TFTrainer:
Subclass and override this method if you want to inject some custom behavior. Subclass and override this method if you want to inject some custom behavior.
""" """
num_examples = test_dataset.cardinality(test_dataset).numpy() num_examples = test_dataset.cardinality().numpy()
if num_examples < 0: if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality") raise ValueError("The training dataset must have an asserted cardinality")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment