"docs/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "271a19fe49861737e04d293eb3f1b923c0ce0104"
Unverified Commit f00f22a3 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels (#15234)

* Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels

* Add test for numpy scalar inputs
parent 4a6a35bc
...@@ -145,26 +145,27 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: ...@@ -145,26 +145,27 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
# Ensure that tensor is created with the correct type # Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.) # (it should be automatically the case, but let's make sure of it.)
if "label" in first and first["label"] is not None: if "label" in first and first["label"] is not None:
if isinstance(first["label"], tf.Tensor): label_col_name = "label"
dtype = tf.int64 if first["label"].dtype.is_integer() else tf.float32
elif isinstance(first["label"], np.ndarray):
dtype = tf.int64 if np.issubdtype(first["label"].dtype, np.integer) else tf.float32
elif isinstance(first["label"], (tuple, list)):
dtype = tf.int64 if isinstance(first["label"][0], int) else tf.float32
else:
dtype = tf.int64 if isinstance(first["label"], int) else tf.float32
batch["labels"] = tf.convert_to_tensor([f["label"] for f in features], dtype=dtype)
elif "label_ids" in first and first["label_ids"] is not None: elif "label_ids" in first and first["label_ids"] is not None:
if isinstance(first["label_ids"], tf.Tensor): label_col_name = "label_ids"
batch["labels"] = tf.stack([f["label_ids"] for f in features]) elif "labels" in first and first["labels"] is not None:
label_col_name = "labels"
else:
label_col_name = None
if label_col_name is not None:
if isinstance(first[label_col_name], tf.Tensor):
dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
elif isinstance(first[label_col_name], (tuple, list)):
dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
else: else:
dtype = tf.int64 if type(first["label_ids"][0]) is int else tf.float32 dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
batch["labels"] = tf.convert_to_tensor([f["label_ids"] for f in features], dtype=dtype) batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
# Handling of all other possible keys. # Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model. # Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in first.items(): for k, v in first.items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
if isinstance(v, (tf.Tensor, np.ndarray)): if isinstance(v, (tf.Tensor, np.ndarray)):
batch[k] = tf.stack([f[k] for f in features]) batch[k] = tf.stack([f[k] for f in features])
else: else:
......
...@@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase): ...@@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["labels"].dtype, tf.int64) self.assertEqual(batch["labels"].dtype, tf.int64)
self.assertEqual(batch["inputs"].shape.as_list(), [8, 10]) self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
def test_numpy_dtype_preservation(self):
data_collator = default_data_collator
# Confirms that numpy inputs are handled correctly even when scalars
features = [{"input_ids": np.array([0, 1, 2, 3, 4]), "label": np.int64(i)} for i in range(4)]
batch = data_collator(features, return_tensors="tf")
self.assertEqual(batch["labels"].dtype, tf.int64)
def test_default_classification_and_regression(self): def test_default_classification_and_regression(self):
data_collator = default_data_collator data_collator = default_data_collator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment