Unverified Commit 2b71531a authored by Kangyan-Zhou, committed by GitHub

Allow benchmarking tool to handle empty response (#12174)


Co-authored-by: Claude <noreply@anthropic.com>
parent 25c50498
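
The change applies one pattern across the eval harness: anywhere a sampler's completion can be None (an empty choice, a refusal, or exhausted retries), coalesce it to "" before string methods or regexes touch it. Below is a minimal sketch of the sampler side of that pattern; the client wiring, the MAX_RETRIES bound, and the model name are illustrative assumptions, not code from this repository.

```python
# Hypothetical sketch of the retry-then-fallback pattern this commit
# adopts in ChatCompletionSampler; MAX_RETRIES and the model name are
# illustrative assumptions, not code from the repository.
import time

import openai

MAX_RETRIES = 5  # assumed bound on retries

def sample(client: openai.OpenAI, message_list: list[dict]) -> str:
    trial = 0
    while trial < MAX_RETRIES:
        try:
            response = client.chat.completions.create(
                model="gpt-4o",  # placeholder model name
                messages=message_list,
            )
            # content can be None (e.g. refusals or empty choices);
            # coalesce so callers always receive a str.
            return response.choices[0].message.content or ""
        except openai.BadRequestError as e:
            print("Bad Request Error", e)
            return ""
        except Exception as e:
            exception_backoff = 2**trial  # exponential back-off
            print(f"Retry {trial} after {exception_backoff} sec", e)
            time.sleep(exception_backoff)
            trial += 1
    # All retries exhausted: return "" instead of None so downstream
    # parsing never crashes on a non-string.
    print("All retry attempts exhausted for request. Returning empty response.")
    return ""
```

Returning "" rather than raising keeps a long benchmark run alive; the failed sample simply scores as incorrect.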
@@ -148,7 +148,7 @@ class ChatCompletionSampler(SamplerBase):
                     reasoning_effort=self.reasoning_effort,
                     extra_body=self.extra_body,
                 )
-                return response.choices[0].message.content
+                return response.choices[0].message.content or ""
             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
             except openai.BadRequestError as e:
                 print("Bad Request Error", e)
@@ -161,7 +161,9 @@ class ChatCompletionSampler(SamplerBase):
                 )
                 time.sleep(exception_backoff)
                 trial += 1
-        # unknown error shall throw exception
+        # If all retries are exhausted, return empty string instead of None
+        print(f"All retry attempts exhausted for request. Returning empty response.")
+        return ""
 
 QUERY_TEMPLATE_MULTICHOICE = """
@@ -261,7 +263,7 @@ def format_multichoice_question(row):
 def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
     prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
     response = sampler([dict(content=prompt, role="user")])
-    return response.lower().strip() == "yes"
+    return (response or "").lower().strip() == "yes"
 
 def _compute_stat(values: list, stat: str):
...
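
The `(response or "")` guard in check_equality is the consumer side of the same fix: with None coalesced to "", a failed grading request compares unequal instead of raising AttributeError on `.lower()`. A tiny illustration, not repository code:

```python
# None.lower() would raise AttributeError; "" just compares unequal.
response = None  # what an exhausted sampler used to return
graded_yes = (response or "").lower().strip() == "yes"
assert graded_yes is False  # scored as "not equal" instead of crashing
```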
@@ -80,6 +80,7 @@ class HumanEval(Eval):
 instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
 
 def find_code(completion):
+    completion = completion or ""
     pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
     matches = pattern.findall(completion)
     extracted_answer = matches[0] if len(matches) >= 1 else completion
...
@@ -54,6 +54,7 @@ class MathEval(Eval):
             sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
         ]
         response_text = sampler(prompt_messages)
+        response_text = response_text or ""
         match = re.search(ANSWER_PATTERN, response_text)
         extracted_answer = match.group(1) if match else None
         score = float(
...
@@ -101,6 +101,7 @@ class MMLUEval(Eval):
             )
         ]
         response_text = sampler(prompt_messages)
+        response_text = response_text or ""
         match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
         extracted_answer = match.group(1) if match else None
         score = 1.0 if extracted_answer == row["Answer"] else 0.0
...
@@ -204,6 +204,7 @@ class MMMUVLMEval(Eval):
         # Sample
         response_text = sampler(prompt_messages)
+        response_text = response_text or ""
 
         # Parse and score
         gold = sample["answer"]
...
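
Each eval (HumanEval, Math, MMLU, MMMU) adds the same one-line guard between sampling and parsing, because `re.search` rejects None outright. A hypothetical repro; the ANSWER_PATTERN below is a stand-in, not the harness's actual pattern:

```python
import re

ANSWER_PATTERN = r"Answer:\s*(\S+)"  # stand-in pattern (assumption)

response_text = None  # a request that produced no content
response_text = response_text or ""  # the guard added in each eval

# On "" this simply finds no match; on None it would raise
# "TypeError: expected string or bytes-like object".
match = re.search(ANSWER_PATTERN, response_text)
extracted_answer = match.group(1) if match else None
assert extracted_answer is None  # sample scores 0.0 and the run continues
```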