Unverified Commit 5a81195e authored by Ali Modarressi's avatar Ali Modarressi Committed by GitHub
Browse files

Fixed label datatype for STS-B (#6492)

* fixed label datatype for sts-b

* naming update

* make style

* make style
parent 12d76241
...@@ -79,6 +79,7 @@ if is_tf_available(): ...@@ -79,6 +79,7 @@ if is_tf_available():
processor = glue_processors[task]() processor = glue_processors[task]()
examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples] examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task) features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
label_type = tf.float32 if task == "sts-b" else tf.int64
def gen(): def gen():
for ex in features: for ex in features:
...@@ -90,7 +91,7 @@ if is_tf_available(): ...@@ -90,7 +91,7 @@ if is_tf_available():
return tf.data.Dataset.from_generator( return tf.data.Dataset.from_generator(
gen, gen,
({k: tf.int32 for k in input_names}, tf.int64), ({k: tf.int32 for k in input_names}, label_type),
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])), ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment