"docs/backend/sampling_params.md" did not exist on "8c3b420eec03ea94e4ccce04681891558ca892ca"
math.py 1.54 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

chenych's avatar
chenych committed
15
import re
chenych's avatar
chenych committed
16
from typing import Dict
chenych's avatar
chenych committed
17

chenych's avatar
v0.3.0  
chenych committed
18
from mathruler.grader import extract_boxed_content, grade_answer
chenych's avatar
chenych committed
19
20


chenych's avatar
v0.3.0  
chenych committed
21
22
23
def format_reward(predict_str: str) -> float:
    r"""Check that a response follows the ``<think>...</think> ... \boxed{...}`` layout.

    Args:
        predict_str: The raw model response text.

    Returns:
        1.0 when the entire string matches the expected layout, otherwise 0.0.
    """
    # DOTALL lets "." cross newlines, so multi-line reasoning inside <think> still matches.
    think_then_boxed = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    if think_then_boxed.fullmatch(predict_str) is None:
        return 0.0
    return 1.0
chenych's avatar
chenych committed
25
26


chenych's avatar
v0.3.0  
chenych committed
27
28
29
def accuracy_reward(predict_str: str, ground_truth: str) -> float:
    r"""Score whether the ``\boxed{...}`` answer in a response matches the reference.

    Args:
        predict_str: The raw model response text.
        ground_truth: The reference answer handed to ``grade_answer``.

    Returns:
        1.0 if ``grade_answer`` judges the extracted answer correct, otherwise 0.0.
    """
    # NOTE(review): what extract_boxed_content returns when no \boxed{} is present is
    # defined by mathruler, not here — confirm before relying on that case.
    is_correct = grade_answer(extract_boxed_content(predict_str), ground_truth)
    return float(bool(is_correct))
chenych's avatar
chenych committed
30

chenych's avatar
chenych committed
31

chenych's avatar
v0.3.0  
chenych committed
32
33
34
35
def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.1) -> Dict[str, float]:
    """Blend the format and accuracy rewards into one weighted score.

    Args:
        predict_str: The raw model response text.
        ground_truth: The reference answer passed through to ``accuracy_reward``.
        format_weight: Share of ``overall`` given to the format score; the accuracy
            score receives the remaining ``1 - format_weight``.

    Returns:
        A dict with ``"overall"`` (the weighted blend), ``"format"`` and ``"accuracy"``.
    """
    # Strip whitespace around "<", ">" and "/" so tags written like "< think >" or
    # "</ think>" still match (handle qwen2.5vl-32b format). This normalized text is
    # also what gets graded, so e.g. "1 / 2" inside the box reaches the grader as "1/2".
    normalized = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)
    fmt = format_reward(normalized)
    acc = accuracy_reward(normalized, ground_truth)
    accuracy_weight = 1 - format_weight
    return {
        "overall": accuracy_weight * acc + format_weight * fmt,
        "format": fmt,
        "accuracy": acc,
    }