Unverified Commit 634a14bd authored by Kayvan Mivehnejad's avatar Kayvan Mivehnejad Committed by GitHub
Browse files

Strengthen input validation and tests for 'parse_raw_prompts’. (#30652)


Signed-off-by: default avatarKayvan Mivehnejad <K.Mivehnejad@gmail.com>
parent 24b65eff
......@@ -45,16 +45,17 @@ def parse_raw_prompts(
# case 4: array of token arrays
if is_list_of(prompt, list):
first = prompt[0]
if not isinstance(first, list):
raise ValueError("prompt expected to be a list of lists")
if len(first) == 0:
raise ValueError("Please provide at least one prompt")
# strict validation: every nested list must be list[int]
if not all(is_list_of(elem, int) for elem in prompt):
raise TypeError("Nested lists must contain only integers")
if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
for elem in prompt:
if not isinstance(elem, list):
raise TypeError(
"prompt must be a list of lists, but found a non-list element."
)
if not is_list_of(elem, int):
raise TypeError(
"Nested lists of tokens must contain only integers."
)
prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment