r1v.py 1.6 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

chenych's avatar
chenych committed
15
import re
chenych's avatar
chenych committed
16
from typing import Dict
chenych's avatar
chenych committed
17
18
19
20
21

from mathruler.grader import grade_answer


def r1v_format_reward(predict_str: str) -> float:
    """Score whether a response follows the think-then-answer layout.

    The entire string must consist of a ``<think>...</think>`` section,
    optional whitespace, and an ``<answer>...</answer>`` section, with no
    text before or after. Both sections may span multiple lines.

    Args:
        predict_str: The raw model response.

    Returns:
        1.0 if the whole response matches the layout, otherwise 0.0.
    """
    layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    # fullmatch (not search) so stray text around the two sections is rejected.
    if layout.fullmatch(predict_str) is None:
        return 0.0
    return 1.0
chenych's avatar
chenych committed
25
26
27
28
29
30


def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
    """Score whether the predicted answer matches the ground truth.

    The answer is taken from the first ``<answer>...</answer>`` section of
    the response; if no such section exists, the whole stripped response is
    used instead. The candidate is then compared against the stripped ground
    truth with ``grade_answer``.

    Args:
        predict_str: The raw model response.
        ground_truth: The reference answer.

    Returns:
        1.0 if ``grade_answer`` accepts the extracted answer, otherwise 0.0.
        Any exception raised while extracting or grading also yields 0.0.
    """
    try:
        ground_truth = ground_truth.strip()
        # re.DOTALL lets the answer body span several lines, matching what
        # r1v_format_reward accepts. Without it a multi-line answer is not
        # extracted and the whole tagged response gets graded instead.
        content_match = re.search(r"<answer>(.*?)</answer>", predict_str, re.DOTALL)
        given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
        if grade_answer(given_answer, ground_truth):
            return 1.0
    except Exception:
        # Deliberate best-effort: a malformed response or a grader failure
        # must score zero rather than abort the caller.
        pass

    return 0.0


chenych's avatar
chenych committed
41
42
43
44
45
46
47
48
def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
    """Combine the format and accuracy rewards into a single score dict.

    Args:
        predict_str: The raw model response.
        ground_truth: The reference answer.

    Returns:
        A dict with keys ``"overall"``, ``"format"`` and ``"accuracy"``.
        ``"overall"`` is the equally weighted (0.5 / 0.5) sum of the other two.
    """
    format_score = r1v_format_reward(predict_str)
    accuracy_score = r1v_accuracy_reward(predict_str, ground_truth)
    overall_score = 0.5 * accuracy_score + 0.5 * format_score
    return {"overall": overall_score, "format": format_score, "accuracy": accuracy_score}