r1v.py 1.61 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

chenych's avatar
chenych committed
15
import re
chenych's avatar
chenych committed
16
from typing import Dict
chenych's avatar
chenych committed
17
18
19
20

from mathruler.grader import grade_answer


chenych's avatar
chenych committed
21
# Compiled once at import time. DOTALL lets "." span newlines, so multi-line
# reasoning and multi-line answers are accepted.
_FORMAT_PATTERN = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)


def format_reward(predict: str) -> float:
    """Return 1.0 if *predict* follows the ``<think>…</think><answer>…</answer>`` layout.

    The whole string must match: it has to start with ``<think>``, end with
    ``</answer>``, and only whitespace may separate the two sections. Any
    leading or trailing text (including a trailing newline) yields 0.0.

    Args:
        predict: The raw model response to check.

    Returns:
        1.0 when the response matches the expected layout, otherwise 0.0.
    """
    return 1.0 if _FORMAT_PATTERN.fullmatch(predict) else 0.0
chenych's avatar
chenych committed
25
26


chenych's avatar
chenych committed
27
def accuracy_reward(predict: str, ground_truth: str) -> float:
    """Return 1.0 if the answer in *predict* is graded as matching *ground_truth*.

    The answer is the stripped content of the first ``<answer>…</answer>``
    block; when no such block exists, the whole stripped response is graded
    instead. Equivalence is decided entirely by ``grade_answer``.

    Args:
        predict: The raw model response.
        ground_truth: The reference answer; surrounding whitespace is ignored.

    Returns:
        1.0 when ``grade_answer`` accepts the extracted answer, otherwise 0.0
        (including when extraction or grading raises).
    """
    try:
        # DOTALL keeps extraction consistent with format_reward: an answer that
        # spans several lines is still captured instead of silently falling
        # back to grading the entire response.
        content_match = re.search(r"<answer>(.*?)</answer>", predict, re.DOTALL)
        given_answer = content_match.group(1).strip() if content_match else predict.strip()
        if grade_answer(given_answer, ground_truth.strip()):
            return 1.0
    except Exception:
        # Deliberate best-effort: a malformed response or a grader failure must
        # score zero rather than abort the caller.
        return 0.0

    return 0.0


chenych's avatar
chenych committed
40
41
42
def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
    """Score a response on layout and correctness and blend the two.

    Args:
        predict: The raw model response.
        ground_truth: The reference answer.
        format_weight: Share of the overall score given to the format check;
            the accuracy check receives ``1 - format_weight``.

    Returns:
        A dict with the keys ``"overall"`` (weighted sum), ``"format"`` and
        ``"accuracy"`` (the two individual scores).
    """
    fmt = format_reward(predict)
    acc = accuracy_reward(predict, ground_truth)
    accuracy_weight = 1 - format_weight
    overall = accuracy_weight * acc + format_weight * fmt

    scores: Dict[str, float] = {}
    scores["overall"] = overall
    scores["format"] = fmt
    scores["accuracy"] = acc
    return scores