Unverified Commit 64743d0a authored by Lukas Weiner's avatar Lukas Weiner Committed by GitHub
Browse files

Raise exceptions instead of asserts (#13938)

parent 32634bce
...@@ -196,12 +196,12 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -196,12 +196,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
def add_examples( def add_examples(
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
): ):
assert labels is None or len(texts_or_text_and_labels) == len( if labels is not None and len(texts_or_text_and_labels) != len(labels):
labels raise ValueError(
), f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}" f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
assert ids is None or len(texts_or_text_and_labels) == len( )
ids if ids is not None and len(texts_or_text_and_labels) != len(ids):
), f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}" raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
if ids is None: if ids is None:
ids = [None] * len(texts_or_text_and_labels) ids = [None] * len(texts_or_text_and_labels)
if labels is None: if labels is None:
...@@ -293,10 +293,10 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -293,10 +293,10 @@ class SingleSentenceClassificationProcessor(DataProcessor):
input_ids = input_ids + ([pad_token] * padding_length) input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
assert len(input_ids) == batch_length, f"Error with input length {len(input_ids)} vs {batch_length}" if len(input_ids) != batch_length:
assert ( raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
len(attention_mask) == batch_length if len(attention_mask) != batch_length:
), f"Error with input length {len(attention_mask)} vs {batch_length}" raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
if self.mode == "classification": if self.mode == "classification":
label = label_map[example.label] label = label_map[example.label]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment