"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "3b2e50a16a2c2ecf8eff63444964e8c4381fcdf9"
Unverified Commit 35befd9c authored by Joe Davison's avatar Joe Davison Committed by GitHub
Browse files

Fix tensor label type inference in default collator (#5250)

* allow tensor label inputs to default collator

* replace try/except with type check
parent fe81f7d1
...@@ -43,7 +43,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten ...@@ -43,7 +43,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
# Ensure that tensor is created with the correct type # Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.) # (it should be automatically the case, but let's make sure of it.)
if "label" in first and first["label"] is not None: if "label" in first and first["label"] is not None:
dtype = torch.long if type(first["label"]) is int else torch.float label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
dtype = torch.long if isinstance(label, int) else torch.float
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
elif "label_ids" in first and first["label_ids"] is not None: elif "label_ids" in first and first["label_ids"] is not None:
if isinstance(first["label_ids"], torch.Tensor): if isinstance(first["label_ids"], torch.Tensor):
......
...@@ -44,6 +44,14 @@ class DataCollatorIntegrationTest(unittest.TestCase): ...@@ -44,6 +44,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["labels"].dtype, torch.long) self.assertEqual(batch["labels"].dtype, torch.long)
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10])) self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
# Labels can already be tensors
features = [{"label": torch.tensor(i), "inputs": torch.randint(10, [10])} for i in range(8)]
batch = default_data_collator(features)
self.assertEqual(batch["labels"].dtype, torch.long)
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
self.assertEqual(batch["labels"].dtype, torch.long)
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
def test_default_with_no_labels(self): def test_default_with_no_labels(self):
features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)] features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
batch = default_data_collator(features) batch = default_data_collator(features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment