# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict

import torch
from transformers import PreTrainedTokenizer

from ...protocol import DataProto
from .config import RewardConfig



# Per-sample reward breakdown returned by a score function.
#   overall  — the scalar reward written into the reward tensor.
#   format   — optional format-compliance component.
#   accuracy — optional answer-accuracy component.
# Functional TypedDict syntax is used because the key "format" shadows a builtin.
RewardScore = TypedDict(
    "RewardScore",
    {
        "overall": float,
        "format": Optional[float],
        "accuracy": Optional[float],
    },
)


# A score function maps (response_str, ground_truth) to a RewardScore.
ScoreFunction = Callable[[str, str], RewardScore]


@dataclass
class FunctionRewardManager:
    """Reward manager that scores model responses with a user-provided function.

    The score function is loaded dynamically from the Python file at
    ``config.score_function``; it must expose a callable named
    ``config.score_function_name`` that takes ``(response_str, ground_truth)``
    (plus any ``config.score_function_kwargs``) and returns a ``RewardScore``.
    """

    config: RewardConfig
    tokenizer: PreTrainedTokenizer

    def __post_init__(self):
        """Load the score function from the configured file path.

        Raises:
            ValueError: if no score function path is configured.
            FileNotFoundError: if the configured file does not exist.
            RuntimeError: if the module cannot be imported or executed.
            AttributeError: if the module lacks the configured function name.
        """
        if self.config.score_function is None:
            raise ValueError("Score function is not provided.")

        if not os.path.exists(self.config.score_function):
            raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")

        spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
        if spec is None or spec.loader is None:
            # spec_from_file_location returns None for paths it cannot load;
            # fail with a clear message instead of an opaque AttributeError below.
            raise RuntimeError(f"Failed to load score function: cannot import {self.config.score_function}.")

        module = importlib.util.module_from_spec(spec)
        sys.modules["custom_score_fn"] = module
        try:
            spec.loader.exec_module(module)
        except Exception as e:
            # Drop the half-initialized module so a later retry does not see stale state.
            sys.modules.pop("custom_score_fn", None)
            # Chain the original exception so the real cause stays in the traceback.
            raise RuntimeError(f"Failed to load score function: {e}") from e

        if not hasattr(module, self.config.score_function_name):
            raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")

        score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
        print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
        # Bind any extra configured kwargs up front; callers pass only (response, ground_truth).
        self.score_fn = partial(score_fn, **self.config.score_function_kwargs)

    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Score every sample in the batch.

        Args:
            data: batch whose ``batch["responses"]`` / ``batch["response_mask"]``
                tensors and ``non_tensor_batch["ground_truth"]`` entries are read.

        Returns:
            A tuple ``(reward_tensor, reward_metrics)``: ``reward_tensor`` has the
            same shape as ``data.batch["responses"]`` with the overall score placed
            at the last valid response token (zeros elsewhere); ``reward_metrics``
            maps each score key to a list of per-sample values.
        """
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem
            response_ids = data_item.batch["responses"]
            response_mask = data_item.batch["response_mask"]
            valid_response_length = response_mask.sum()
            # Only decode the unmasked prefix of the response.
            valid_response_ids = response_ids[:valid_response_length]

            response_str = self.tokenizer.decode(
                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
            )
            ground_truth = data_item.non_tensor_batch["ground_truth"]

            score = self.score_fn(response_str, ground_truth)
            # NOTE(review): if valid_response_length is 0 this writes to index -1
            # (the last position); presumably masked out downstream — confirm.
            reward_tensor[i, valid_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)

        return reward_tensor, reward_metrics