Unverified Commit 84b9579d authored by Matt's avatar Matt Committed by GitHub
Browse files

Remove unneeded `to_tensor()` in TF inline example (#14140)

parent 1967c43e
...@@ -240,11 +240,11 @@ Then we convert everything in big tensors and use the :obj:`tf.data.Dataset.from ...@@ -240,11 +240,11 @@ Then we convert everything in big tensors and use the :obj:`tf.data.Dataset.from
.. code-block:: python .. code-block:: python
train_features = {x: tf_train_dataset[x].to_tensor() for x in tokenizer.model_input_names} train_features = {x: tf_train_dataset[x] for x in tokenizer.model_input_names}
train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"])) train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"]))
train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(8) train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(8)
eval_features = {x: tf_eval_dataset[x].to_tensor() for x in tokenizer.model_input_names} eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names}
eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"])) eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"]))
eval_tf_dataset = eval_tf_dataset.batch(8) eval_tf_dataset = eval_tf_dataset.batch(8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment