custom.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict
from typing import Callable, Dict, List, Tuple, TypedDict

import torch
from transformers import PreTrainedTokenizer

from ...protocol import DataProto
from ...utils.reward_score import math_compute_score, r1v_compute_score
from .config import RewardConfig


class RewardScore(TypedDict):
    overall: float
    format: float
    accuracy: float
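

# A minimal sketch of the score-function contract used below (illustrative only,
# not one of the shipped scorers): any callable mapping (response_str,
# ground_truth) to a RewardScore can be plugged in. The exact-match rule and the
# constant format score here are assumptions made for demonstration.
def exact_match_compute_score(response: str, ground_truth: str) -> RewardScore:
    accuracy = 1.0 if response.strip() == ground_truth.strip() else 0.0
    return {"overall": accuracy, "format": 1.0, "accuracy": accuracy}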


class CustomRewardManager:
    """Assigns rule-based rewards to decoded responses via the configured score function."""
    def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
        self.config = config
        self.tokenizer = tokenizer
        if config.score_function == "math":
            self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
        elif config.score_function == "r1v":
            self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
        else:
            raise NotImplementedError(f"Unknown score function {config.score_function}.")

    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Score each response against its ground truth; return token-level rewards and metrics."""
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem
            response_ids = data_item.batch["responses"]
            response_mask = data_item.batch["response_mask"]
            valid_response_length = response_mask.sum()  # count of non-padded response tokens
            valid_response_ids = response_ids[:valid_response_length]

            response_str = self.tokenizer.decode(
                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
            )
            ground_truth = data_item.non_tensor_batch["ground_truth"]

            score = self.compute_score(response_str, ground_truth)
            # credit the sequence-level score to the last valid response token
            reward_tensor[i, valid_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)

        return reward_tensor, reward_metrics
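

# Usage sketch (hedged: the construction details below are assumptions for
# illustration, not verified against the surrounding project):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   reward_fn = CustomRewardManager(tokenizer, RewardConfig(score_function="math"))
#   reward_tensor, reward_metrics = reward_fn(batch)  # batch: DataProto
#
# reward_tensor has shape (batch_size, response_length) and is zero everywhere
# except the last valid token of each response, which holds score["overall"];
# reward_metrics maps each RewardScore key to a list of per-sample values.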