math.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Dict, List

from mathruler.grader import extract_boxed_content, grade_answer


def format_reward(predict: str) -> float:
    r"""Return 1.0 if the prediction has a <think>...</think> block followed by a \boxed{...} answer, else 0.0."""
    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    format_match = re.fullmatch(pattern, predict)
    return 1.0 if format_match else 0.0


def accuracy_reward(predict: str, ground_truth: str) -> float:
    r"""Return 1.0 if the \boxed{...} answer extracted from the prediction matches the ground truth, else 0.0."""
    answer = extract_boxed_content(predict)
    return 1.0 if grade_answer(answer, ground_truth) else 0.0


def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
    """Score each prediction as a weighted sum of its format reward and accuracy reward."""
    scores = []
    for predict, ground_truth in zip(predicts, ground_truths):
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)  # strip spaces around "<", ">" and "/" so tags like "< think >" still match (handles qwen2.5vl-32b outputs)
        format_score = format_reward(predict)
        accuracy_score = accuracy_reward(predict, ground_truth)
        scores.append(
            {
                "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
                "format": format_score,
                "accuracy": accuracy_score,
            }
        )

    return scores
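

# Illustrative usage (a minimal sketch, not part of the original file); the sample
# response and ground truth below are made-up inputs.
if __name__ == "__main__":
    sample_predicts = ["<think>2 + 2 equals 4.</think> The answer is \\boxed{4}."]
    sample_ground_truths = ["4"]
    print(compute_score(sample_predicts, sample_ground_truths))
    # Expected output, assuming mathruler grades "4" against "4" as correct:
    # [{'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}]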