Unverified commit e5904168 authored by Matt, committed by GitHub

DataCollatorForTokenClassification numpy fix (#13609)

* Fix issue when labels are supplied as Numpy array instead of list

* Fix issue when labels are supplied as Numpy array instead of list

* Fix same issue in the `TokenClassification` data collator

* Style pass
parent 88dbbfb2
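
A minimal sketch (not part of the commit) of what breaks without the extra `list()` call: for a plain Python list, `+` concatenates the padding, but for a NumPy array `+` means element-wise arithmetic, so the old expression either raises a broadcast error or silently corrupts the labels.

```python
import numpy as np

label_pad_token_id = -100
sequence_length = 5

# Plain list: `+` concatenates, so padding works as intended.
label = [1, 0, 2]
padded = label + [label_pad_token_id] * (sequence_length - len(label))
# -> [1, 0, 2, -100, -100]

# NumPy array: `+` is element-wise addition, so the same expression raises
# "operands could not be broadcast together" (or, when the lengths happen to
# match, silently adds -100 to every label). Converting to a list first
# restores concatenation, which is what the collator does after this fix.
label = np.array([1, 0, 2])
padded = list(label) + [label_pad_token_id] * (sequence_length - len(label))
# -> [1, 0, 2, -100, -100]
```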
@@ -291,11 +291,11 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
         padding_side = self.tokenizer.padding_side
         if padding_side == "right":
             batch[label_name] = [
-                label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+                list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
             ]
         else:
             batch[label_name] = [
-                [self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels
+                [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
             ]
         batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
@@ -321,9 +321,13 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
         sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
         padding_side = self.tokenizer.padding_side
         if padding_side == "right":
-            batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
+            batch["labels"] = [
+                list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+            ]
         else:
-            batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
+            batch["labels"] = [
+                [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
+            ]
         batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
         return batch
@@ -348,9 +352,13 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
         sequence_length = np.array(batch["input_ids"]).shape[1]
         padding_side = self.tokenizer.padding_side
         if padding_side == "right":
-            batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
+            batch["labels"] = [
+                list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+            ]
         else:
-            batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
+            batch["labels"] = [
+                [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
+            ]
         batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
         return batch
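
A hedged usage sketch of the collator after this change; the checkpoint name and feature values are illustrative and not taken from the commit.

```python
import numpy as np
from transformers import AutoTokenizer, DataCollatorForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative checkpoint
collator = DataCollatorForTokenClassification(tokenizer, return_tensors="np")

# "labels" supplied as NumPy arrays (e.g. read straight from a dataset column)
# now collate the same as plain Python lists.
features = [
    {"input_ids": [101, 1188, 102], "labels": np.array([0, 1, 0])},
    {"input_ids": [101, 1188, 1110, 2503, 102], "labels": np.array([0, 1, 2, 1, 0])},
]
batch = collator(features)
# batch["labels"] is a (2, 5) int64 array; the shorter example is right-padded with -100.
```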