Unverified Commit 9ad83059 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix dataset cardinality (#7678)

* Fix test

* Fix cardinality issue

* Fix test
parent a1ac0828
...@@ -96,6 +96,9 @@ def get_tfds( ...@@ -96,6 +96,9 @@ def get_tfds(
else None else None
) )
if train_ds is not None:
train_ds = train_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TRAIN])))
val_ds = ( val_ds = (
tf.data.Dataset.from_generator( tf.data.Dataset.from_generator(
gen_val, gen_val,
...@@ -106,6 +109,9 @@ def get_tfds( ...@@ -106,6 +109,9 @@ def get_tfds(
else None else None
) )
if val_ds is not None:
val_ds = val_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.VALIDATION])))
test_ds = ( test_ds = (
tf.data.Dataset.from_generator( tf.data.Dataset.from_generator(
gen_test, gen_test,
...@@ -116,6 +122,9 @@ def get_tfds( ...@@ -116,6 +122,9 @@ def get_tfds(
else None else None
) )
if test_ds is not None:
test_ds = test_ds.apply(tf.data.experimental.assert_cardinality(len(ds[datasets.Split.TEST])))
return train_ds, val_ds, test_ds, label2id return train_ds, val_ds, test_ds, label2id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment