# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict
from typing import Any, Callable, Dict, Tuple, TypedDict

import torch
from transformers import PreTrainedTokenizer

from ...protocol import DataProto
from ...utils.reward_score import math_compute_score, r1v_compute_score


class RewardScore(TypedDict):
    """Per-sample score breakdown produced by a compute-score function."""

    # Scalar reward written into the reward tensor for the sample.
    overall: float
    # Format-related score component; only collected into metrics here.
    # NOTE(review): exact semantics live in the compute-score functions — confirm there.
    format: float
    # Accuracy-related score component; only collected into metrics here.
    accuracy: float


class CustomRewardManager:
    """Score decoded rollout responses against ground truths.

    For each sample, the configured scoring function is applied to the
    decoded response string and its ground truth; the ``overall`` score is
    placed on the last valid response token (all other positions stay zero),
    and every score component is accumulated into per-key metric lists.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, compute_score: str):
        """Configure the tokenizer and select the scoring function.

        Args:
            tokenizer: Tokenizer used to decode response token ids.
            compute_score: Name of the scoring function; one of ``"math"``
                or ``"r1v"``.

        Raises:
            NotImplementedError: If ``compute_score`` names no known scorer.
        """
        self.tokenizer = tokenizer
        # Dispatch table instead of an if/elif chain; kept local so the
        # module-level scorer names are only resolved when instantiated.
        score_fns: Dict[str, Callable[[str, str], RewardScore]] = {
            "math": math_compute_score,
            "r1v": r1v_compute_score,
        }
        if compute_score not in score_fns:
            # Name the offending value so a config typo is easy to diagnose.
            raise NotImplementedError(
                f"Unsupported compute_score: {compute_score!r}; expected one of {sorted(score_fns)}."
            )
        self.compute_score: Callable[[str, str], RewardScore] = score_fns[compute_score]

    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Compute rewards and metrics for a batch of rollouts.

        Args:
            data: Batch holding ``responses`` and ``response_mask`` tensors
                and a ``ground_truth`` non-tensor field per sample.

        Returns:
            A pair ``(reward_tensor, reward_metrics)``: a float32 tensor
            shaped like ``data.batch["responses"]`` with each row's overall
            score at its last valid response token, and a dict mapping each
            score key to the list of per-sample values.
        """
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem
            response_ids = data_item.batch["responses"]
            response_mask = data_item.batch["response_mask"]
            valid_response_length = response_mask.sum()
            valid_response_ids = response_ids[:valid_response_length]

            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
            ground_truth = data_item.non_tensor_batch["ground_truth"]

            score = self.compute_score(response_str, ground_truth)
            # Place the scalar reward on the last valid response token.
            # Guard against fully-masked (empty) responses: the original
            # `valid_response_length - 1` would index -1 and write the
            # reward onto the final padding position of the row.
            if valid_response_length > 0:
                reward_tensor[i, valid_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)

        return reward_tensor, reward_metrics