import re
def extract_xml_answer(text: str) -> str:
    """Return the stripped text found inside the last ``<answer>`` element.

    The text after the last ``<answer>`` opening tag is taken, cut at the
    first ``</answer>`` closing tag that follows, and stripped of
    surrounding whitespace.

    When ``<answer>`` does not occur, the whole input (up to the first
    ``</answer>``, if any) is returned stripped; no exception is raised.

    NOTE(review): the delimiters had been reduced to empty strings, and
    ``str.split("")`` raises ``ValueError`` on every call. The tag name is
    reconstructed from the function name -- confirm it matches the prompt
    format used for training.
    """
    # rsplit/split with maxsplit=1 pick the same pieces as a full split
    # followed by [-1] / [0], without building the intermediate lists.
    after_open = text.rsplit("<answer>", 1)[-1]
    inside = after_open.split("</answer>", 1)[0]
    return inside.strip()
# Define format reward function
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that loosely follow the reasoning/answer layout.

    A completion earns 0.5 when its text *starts with* a
    ``<reasoning>...</reasoning>`` element followed, after optional
    whitespace, by an ``<answer>...</answer>`` element; otherwise 0.0.
    Newlines inside either element are allowed (``re.DOTALL``).

    Args:
        completions: sequence of completions; the text is read from
            ``completion[0]["content"]`` of each one.
        **kwargs: accepted and ignored.

    Returns:
        One reward (0.5 or 0.0) per completion, in input order.

    NOTE(review): the tag literals had been lost from this pattern, which
    made it match every string (constant 0.5 reward). ``<answer>`` follows
    the name of ``extract_xml_answer``; ``<reasoning>`` is assumed --
    confirm against the system prompt.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.match anchors at the start of the string only; trailing text after
    # </answer> is tolerated by this "soft" check.
    return [
        0.5 if re.match(pattern, response, flags=re.DOTALL) else 0.0
        for response in responses
    ]
# Define accuracy reward function
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the reference.

    The text of each completion (``completion[0]['content']``) is passed
    through ``extract_xml_answer`` and compared, by exact string equality,
    with the reference at the same position in ``answer``.

    Args:
        prompts: accepted for signature compatibility; not read.
        completions: sequence of completions to score.
        answer: sequence of reference answers, paired positionally with
            ``completions`` (pairing stops at the shorter of the two).
        **kwargs: accepted and ignored.

    Returns:
        2.0 for each exact match, 0.0 otherwise.
    """
    # The former unused lookup of prompts[0][-1]['content'] was dropped: it
    # only fed a commented-out debug print and could raise on empty prompts.
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [
        2.0 if extracted == reference else 0.0
        for extracted, reference in zip(extracted_responses, answer)
    ]
def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions whose extracted answer consists only of digits.

    Each completion's text is read from ``completion[0]['content']`` and
    handed to ``extract_xml_answer``; the result scores 0.5 when
    ``str.isdigit()`` is true for it and 0.0 otherwise (so an empty string,
    a sign or a decimal point all score 0.0).

    Extra keyword arguments are accepted and ignored.
    """
    texts = [item[0]['content'] for item in completions]
    rewards = []
    for candidate in map(extract_xml_answer, texts):
        if candidate.isdigit():
            rewards.append(0.5)
        else:
            rewards.append(0.0)
    return rewards
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact line-by-line layout.

    A completion earns 0.5 only when its whole text has the form::

        <reasoning>
        ...
        </reasoning>
        <answer>
        ...
        </answer>

    with each tag on its own line and a newline after ``</answer>``;
    otherwise 0.0.

    Args:
        completions: sequence of completions; the text is read from
            ``completion[0]["content"]`` of each one.
        **kwargs: accepted and ignored.

    Returns:
        One reward (0.5 or 0.0) per completion, in input order.

    NOTE(review): the tag literals had been lost from this pattern, leaving
    only the newlines. ``<answer>`` follows the name of
    ``extract_xml_answer``; ``<reasoning>`` is assumed -- confirm against
    the system prompt.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span several lines; without it any multi-line
    # reasoning or answer body could never match (the soft variant in this
    # file already passes the flag).
    return [
        0.5 if re.match(pattern, response, flags=re.DOTALL) else 0.0
        for response in responses
    ]
def count_xml(text) -> float:
    """Score how closely *text* follows the reasoning/answer tag layout.

    Each of the four tag markers that occurs exactly once adds 0.125, so a
    well-formed completion reaches 0.5. Text trailing the closing answer
    tag is penalised at 0.001 per character.

    NOTE(review): the tag literals had been lost from these checks (two of
    them had become the identical test ``"\\n\\n"``), and the nesting of the
    penalty lines under their ``if`` is reconstructed. ``<answer>`` follows
    the name of ``extract_xml_answer``; ``<reasoning>`` is assumed --
    confirm against the system prompt.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # Penalise whatever follows the last "\n</answer>\n". If that marker
        # is absent, split() returns the whole text, so the penalty then
        # grows with the full length of the completion.
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        # The "- 1" leaves one character after the closing tag unpenalised
        # (the single newline that ends a well-formed completion).
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score every completion with ``count_xml``.

    The text of each completion is read from ``completion[0]["content"]``;
    the returned list holds the ``count_xml`` result for each text, in
    input order. Extra keyword arguments are accepted and ignored.
    """
    texts = [item[0]["content"] for item in completions]
    return list(map(count_xml, texts))