# This file is adapted from code in https://github.com/Jiayi-Pan/TinyZero
import re

from datasets import load_dataset

raw_dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split='train')

TRAIN_SIZE = 327680
TEST_SIZE = 1024

# Explicit check (not `assert`, which `python -O` strips): the train and test
# splits below are disjoint consecutive index ranges and must both fit.
if len(raw_dataset) < TRAIN_SIZE + TEST_SIZE:
    raise ValueError(
        f"Dataset has {len(raw_dataset)} rows; need at least "
        f"{TRAIN_SIZE + TEST_SIZE} for the train/test split."
    )

train_dataset = raw_dataset.select(range(TRAIN_SIZE))
test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))

# The final equation is expected between <answer> and </answer> tags.
_ANSWER_PATTERN = re.compile(r'<answer>(.*?)</answer>')

# Digits, + - * /, parentheses, decimal points and whitespace only.
_ALLOWED_EQUATION_PATTERN = re.compile(r'^[\d+\-*/().\s]+$')


def extract_solution(solution_str: str, remove_prompt: bool = False):
    """Extract the equation from the solution string.

    Args:
        solution_str: The generated text, optionally still including the prompt.
        remove_prompt: If True, first drop everything up to and including the
            first "Assistant:" (or, failing that, "<|im_start|>assistant")
            marker; if neither marker is present, return None.

    Returns:
        The stripped contents of the last ``<answer>...</answer>`` span on the
        final line of ``solution_str``, or None if there is no such span.
    """
    if remove_prompt:
        if "Assistant:" in solution_str:
            solution_str = solution_str.split("Assistant:", 1)[1]
        elif "<|im_start|>assistant" in solution_str:
            solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
        else:
            return None

    # Only the last line is searched, so an answer followed by further lines
    # (including a trailing newline) is not found.
    last_line = solution_str.split('\n')[-1]
    # With a single capture group, findall returns the group contents.
    matches = _ANSWER_PATTERN.findall(last_line)
    return matches[-1].strip() if matches else None


def validate_equation(equation_str, available_numbers):
    """Check that the equation uses each available number exactly once.

    The multiset of integers appearing in ``equation_str`` must equal the
    multiset ``available_numbers``: no number missing, reused, or invented.

    Returns:
        True if the multisets match; False otherwise, including when the
        inputs have unusable types (e.g. a non-string equation).
    """
    try:
        numbers_in_eq = sorted(int(n) for n in re.findall(r'\d+', equation_str))
        return numbers_in_eq == sorted(available_numbers)
    except (TypeError, ValueError):
        return False


def evaluate_equation(equation_str):
    """Evaluate an arithmetic equation built from + - * / and parentheses.

    NOTE(review): this still relies on eval(). The character whitelist, the
    explicit rejection of ``**``, and empty builtins are the only sandbox.

    Returns:
        The numeric result, or None if the input is not a string, contains
        disallowed characters or ``**``, fails to parse, or raises while
        evaluating (e.g. division by zero).
    """
    if not isinstance(equation_str, str):
        return None
    if not _ALLOWED_EQUATION_PATTERN.match(equation_str):
        return None
    # "**" passes the character whitelist, but exponentiation is not a
    # Countdown operator and e.g. "99**99**99" would stall the reward worker.
    if '**' in equation_str:
        return None
    try:
        return eval(equation_str, {"__builtins__": None}, {})
    except Exception:  # SyntaxError, ZeroDivisionError, TypeError, ...
        return None


def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.0):
    """The scoring function for countdown task.

    Args:
        solution_str: The solution text.
        ground_truth: Mapping with ``'target'`` (the number to reach) and
            ``'nums'`` (the numbers that must each be used exactly once).
        method: Currently unused; kept for interface compatibility.
        format_score: Score when an equation was extracted but uses the wrong
            numbers, cannot be evaluated, or does not reach the target.
        score: Score for a correct equation.

    Returns:
        0 if no equation could be extracted, ``score`` if the equation uses
        the right numbers and evaluates to the target (within 1e-5),
        otherwise ``format_score``.
    """
    target = ground_truth['target']
    numbers = ground_truth['nums']

    equation = extract_solution(solution_str=solution_str)
    do_print = False  # e.g. random.randint(1, 64) == 1 to sample logs

    def log(message):
        if do_print:
            print(message)

    log("--------------------------------")
    log(f"Target: {target} | Numbers: {numbers}")
    log(f"Extracted equation: {equation}")
    log(f"Solution string: {solution_str}")

    if equation is None:
        log("No equation found")
        return 0

    # Validate equation uses correct numbers
    if not validate_equation(equation, numbers):
        log("Invalid equation")
        return format_score

    result = evaluate_equation(equation)
    if result is None:
        log("Could not evaluate equation")
        return format_score

    try:
        # Tolerance absorbs floating point error introduced by "/".
        correct = abs(result - target) < 1e-5
    except (TypeError, ValueError, ArithmeticError):
        log("Error evaluating equation")
        return format_score

    if correct:
        log(f"Correct equation: {equation} = {result}")
        return score
    log(f"Wrong result: equation = {result}, target = {target}")
    return format_score