Unverified Commit fcf06524 authored by Julien Plu, committed by GitHub
Browse files

Fix TensorFlow dataset generator (#4881)

* fix TensorFlow generator

* Better features handling

* Apply style

* Apply style

* Fix squad as well

* Apply style

* Better factorization of TF Tensors creation
parent 501040fd
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import logging import logging
import os import os
from dataclasses import asdict
from enum import Enum from enum import Enum
from typing import List, Optional, Union from typing import List, Optional, Union
...@@ -81,26 +82,16 @@ if is_tf_available(): ...@@ -81,26 +82,16 @@ if is_tf_available():
def gen(): def gen():
for ex in features: for ex in features:
yield ( d = {k: v for k, v in asdict(ex).items() if v is not None}
{ label = d.pop("label")
"input_ids": ex.input_ids, yield (d, label)
"attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids, input_names = ["input_ids"] + tokenizer.model_input_names
},
ex.label,
)
return tf.data.Dataset.from_generator( return tf.data.Dataset.from_generator(
gen, gen,
({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64), ({k: tf.int32 for k in input_names}, tf.int64),
( ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
},
tf.TensorShape([]),
),
) )
......
...@@ -389,6 +389,23 @@ def squad_convert_examples_to_features( ...@@ -389,6 +389,23 @@ def squad_convert_examples_to_features(
def gen(): def gen():
for i, ex in enumerate(features): for i, ex in enumerate(features):
if ex.token_type_ids is None:
yield (
{
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"feature_index": i,
"qas_id": ex.qas_id,
},
{
"start_positions": ex.start_position,
"end_positions": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
"is_impossible": ex.is_impossible,
},
)
else:
yield ( yield (
{ {
"input_ids": ex.input_ids, "input_ids": ex.input_ids,
...@@ -407,6 +424,7 @@ def squad_convert_examples_to_features( ...@@ -407,6 +424,7 @@ def squad_convert_examples_to_features(
) )
# Why have we split the batch into a tuple? PyTorch just has a list of tensors. # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
if "token_type_ids" in tokenizer.model_input_names:
train_types = ( train_types = (
{ {
"input_ids": tf.int32, "input_ids": tf.int32,
...@@ -440,6 +458,33 @@ def squad_convert_examples_to_features( ...@@ -440,6 +458,33 @@ def squad_convert_examples_to_features(
"is_impossible": tf.TensorShape([]), "is_impossible": tf.TensorShape([]),
}, },
) )
else:
train_types = (
{"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
{
"start_positions": tf.int64,
"end_positions": tf.int64,
"cls_index": tf.int64,
"p_mask": tf.int32,
"is_impossible": tf.int32,
},
)
train_shapes = (
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"feature_index": tf.TensorShape([]),
"qas_id": tf.TensorShape([]),
},
{
"start_positions": tf.TensorShape([]),
"end_positions": tf.TensorShape([]),
"cls_index": tf.TensorShape([]),
"p_mask": tf.TensorShape([None]),
"is_impossible": tf.TensorShape([]),
},
)
return tf.data.Dataset.from_generator(gen, train_types, train_shapes) return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment