ppo_utils.py
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.packages import is_requests_available


if is_requests_available():
    import requests


if TYPE_CHECKING:
    from transformers import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead


def get_rewards_from_server(server_url: str, messages: List[str]) -> "torch.Tensor":
    r"""
    Gets reward scores from the API server.
    """
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers)
    rewards = json.loads(response.text)["scores"]
    return torch.Tensor(rewards)
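
# Illustrative usage sketch (not part of the original module; the URL below is an
# assumption). The server is expected to accept a JSON payload with "model" and
# "messages" fields and reply with a JSON body containing a "scores" list, one
# scalar per message:
#
#     messages = ["Human: How are you?\nAssistant: I'm fine, thank you!"]
#     rewards = get_rewards_from_server("http://localhost:8000/get_reward", messages)
#     print(rewards.shape)  # torch.Size([1])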


def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
    r"""
    Replaces the default/reward modules in the model. The model is already unwrapped.
    """
    v_head_layer = model.v_head.summary
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

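        # under DeepSpeed ZeRO-3, parameters are partitioned across ranks, so the
        # value-head weight and bias must be gathered before their .data can be
        # read or overwritten below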
        params = [v_head_layer.weight, v_head_layer.bias]
        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
    else:
        context_maybe_zero3 = nullcontext()

    model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
    with context_maybe_zero3:
        if target == "reward":  # save default head temporarily
            setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone())
            setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())

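        # the "{target}_head_weight" / "{target}_head_bias" buffers are assumed to have
        # been registered on the model beforehand (e.g. when the reward model is loaded);
        # otherwise get_buffer() below raises an AttributeError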
        device = v_head_layer.weight.device
        v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
        v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
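
# Illustrative call pattern (a sketch, not part of the original module): a PPO trainer
# is expected to swap in the reward adapter and value head to score generated responses,
# then swap the default ones back before the optimization step, e.g.:
#
#     replace_model(unwrapped_model, target="reward")
#     ...  # forward pass through the model to obtain reward scores
#     replace_model(unwrapped_model, target="default")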


def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
    r"""
    Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
    """
    layer_norm_params = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            layer_norm_params[name] = param.data.detach().clone()
            param.data = param.data.to(model.config.torch_dtype)

    return layer_norm_params


def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
    r"""
    Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
    """
    if layernorm_params is None:
        return

    for name, param in model.named_parameters():
        if name in layernorm_params:
            param.data = layernorm_params[name]
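

# Illustrative usage sketch (an assumption about the calling code, not part of the
# original module): dump_layernorm and restore_layernorm are meant to be used as a
# pair around generation, e.g.:
#
#     layernorm_params = dump_layernorm(unwrapped_model)    # downcast fp32 layernorms
#     generate_output = unwrapped_model.generate(**batch)
#     restore_layernorm(unwrapped_model, layernorm_params)  # restore the fp32 copies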