function.py 3.33 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

chenych's avatar
update  
chenych committed
15
16
17
import importlib.util
import os
import sys
chenych's avatar
chenych committed
18
from collections import defaultdict
chenych's avatar
v0.3.0  
chenych committed
19
from dataclasses import dataclass
chenych's avatar
update  
chenych committed
20
21
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
chenych's avatar
chenych committed
22

chenych's avatar
chenych committed
23
24
25
import torch
from transformers import PreTrainedTokenizer

chenych's avatar
chenych committed
26
from ...protocol import DataProto
chenych's avatar
Update  
chenych committed
27
from .config import RewardConfig
chenych's avatar
chenych committed
28
29
30
31


class RewardScore(TypedDict):
    """Score dictionary returned by a scoring function for a single response.

    ``FunctionRewardManager`` writes ``overall`` into the reward tensor and
    records every key/value pair of the dictionary as a per-sample metric.
    """

    # Scalar reward assigned to the response.
    overall: float
    # Component scores; may be None. Presumably format-compliance and
    # answer-accuracy sub-scores - confirm against the scoring functions in use.
    format: Optional[float]
    accuracy: Optional[float]


chenych's avatar
v0.3.0  
chenych committed
36
# Signature expected of a user-supplied scoring function: two positional
# arguments in, a RewardScore out. FunctionRewardManager passes the decoded
# response text first and the sample's ground truth second.
ScoreFunction = Callable[[str, str], RewardScore]
chenych's avatar
update  
chenych committed
37
38


chenych's avatar
v0.3.0  
chenych committed
39
40
41
42
@dataclass
class FunctionRewardManager:
    """Compute rewards by calling a user-supplied Python scoring function.

    On construction the file at ``config.score_function`` is imported as the
    module ``custom_score_fn`` and the attribute named
    ``config.score_function_name`` is taken from it. The entries of
    ``config.score_function_kwargs`` are bound to that function as keyword
    arguments and the result is stored on ``self.score_fn``.
    """

    config: RewardConfig
    tokenizer: PreTrainedTokenizer

    def __post_init__(self):
        """Load the scoring function described by ``self.config``.

        Raises:
            ValueError: ``config.score_function`` is ``None``.
            FileNotFoundError: ``config.score_function`` does not exist.
            RuntimeError: the file cannot be imported as a Python module,
                either because no loader handles it or because executing it
                raised an exception (chained as ``__cause__``).
            AttributeError: the loaded module has no attribute named
                ``config.score_function_name``.
        """
        score_path = self.config.score_function
        if score_path is None:
            raise ValueError("Score function is not provided.")

        if not os.path.exists(score_path):
            raise FileNotFoundError(f"Score function file {score_path} not found.")

        module_name = "custom_score_fn"
        spec = importlib.util.spec_from_file_location(module_name, score_path)
        # spec_from_file_location returns None when no loader recognises the
        # file (for example a path without a Python suffix).
        if spec is None or spec.loader is None:
            raise RuntimeError(
                f"Failed to load score function: cannot import {score_path} as a Python module."
            )

        module = importlib.util.module_from_spec(spec)
        previous_module = sys.modules.get(module_name)
        # Registered before execution, following the importlib recipe for
        # importing a source file directly.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except Exception as e:
            # Do not leave a half-initialised module behind: restore whatever
            # was registered under this name before the attempt.
            if previous_module is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous_module
            raise RuntimeError(f"Failed to load score function: {e}") from e

        if not hasattr(module, self.config.score_function_name):
            raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")

        score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
        print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
        # Tolerate a missing kwargs mapping instead of failing inside partial().
        self.score_fn = partial(score_fn, **(self.config.score_function_kwargs or {}))

    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Score every sample in ``data``.

        For each sample the first ``response_mask.sum()`` response tokens are
        decoded with the tokenizer and passed, together with the sample's
        ``ground_truth``, to the loaded scoring function.

        Args:
            data: batch providing ``batch["responses"]`` and, per item,
                ``batch["response_mask"]`` and ``non_tensor_batch["ground_truth"]``.

        Returns:
            A tuple ``(reward_tensor, reward_metrics)``. ``reward_tensor`` is a
            float32 tensor shaped like ``data.batch["responses"]``, zero except
            at the last valid response position of each row, which holds the
            sample's ``overall`` score. ``reward_metrics`` maps every key of
            the returned score dictionaries to the list of per-sample values.
        """
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem
            response_ids = data_item.batch["responses"]
            response_mask = data_item.batch["response_mask"]
            # Convert to a plain int so slicing and indexing work whatever the
            # mask dtype is (a float 0-dim tensor cannot be used as an index).
            valid_response_length = int(response_mask.sum().item())
            valid_response_ids = response_ids[:valid_response_length]

            response_str = self.tokenizer.decode(
                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
            )
            ground_truth = data_item.non_tensor_batch["ground_truth"]

            score = self.score_fn(response_str, ground_truth)
            # NOTE(review): when the mask sums to 0 this index is -1, so the
            # score lands on the last position of the row - confirm that this
            # is the intended handling of empty responses.
            reward_tensor[i, valid_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)

        return reward_tensor, reward_metrics