Unverified commit 5c593718, authored by Matt and committed by GitHub
Browse files

Fix DataCollatorForSeq2Seq when labels are supplied as a NumPy array instead of a list (#13582)

* Fix issue when labels are supplied as a NumPy array instead of a list

* Fix issue when labels are supplied as a NumPy array instead of a list
parent 421929b5
...@@ -517,6 +517,8 @@ class DataCollatorForSeq2Seq: ...@@ -517,6 +517,8 @@ class DataCollatorForSeq2Seq:
return_tensors: str = "pt" return_tensors: str = "pt"
def __call__(self, features, return_tensors=None): def __call__(self, features, return_tensors=None):
import numpy as np
if return_tensors is None: if return_tensors is None:
return_tensors = self.return_tensors return_tensors = self.return_tensors
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
...@@ -527,9 +529,14 @@ class DataCollatorForSeq2Seq: ...@@ -527,9 +529,14 @@ class DataCollatorForSeq2Seq:
padding_side = self.tokenizer.padding_side padding_side = self.tokenizer.padding_side
for feature in features: for feature in features:
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
feature["labels"] = ( if isinstance(feature["labels"], list):
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] feature["labels"] = (
) feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
)
elif padding_side == "right":
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
else:
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
features = self.tokenizer.pad( features = self.tokenizer.pad(
features, features,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment