save_and_load.py

import os
import torch
from transformers.trainer import WEIGHTS_NAME

from llmtuner.extras.logging import get_logger


logger = get_logger(__name__)


def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
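    """
    Loads the value head parameters saved in ``checkpoint_dir`` and attaches them to ``model``
    as non-persistent buffers; returns ``False`` if no value head checkpoint is found there.

    The weights are expected in the standard transformers checkpoint file (``WEIGHTS_NAME``,
    i.e. ``pytorch_model.bin``) under the keys ``v_head.summary.weight`` and ``v_head.summary.bias``.
    """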
    vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if not os.path.exists(vhead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
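    # Load the checkpoint on CPU so the extra copy does not consume GPU memory.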
    vhead_params = torch.load(vhead_file, map_location="cpu")
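    # Expose the trained value head as "reward_head_*" buffers and create zero-initialized
    # "default_head_*" buffers of the same shape; non-persistent buffers follow the model
    # across devices but are excluded from its state_dict.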
    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
    return True
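
# Minimal usage sketch (illustrative assumptions, not part of the original module): the target is
# expected to be a value-head model such as trl's AutoModelForCausalLMWithValueHead, and the
# checkpoint directory one whose WEIGHTS_NAME file contains "v_head.summary.*" tensors.
#
#   from trl import AutoModelForCausalLMWithValueHead
#   model = AutoModelForCausalLMWithValueHead.from_pretrained("path/to/base_model")
#   if load_valuehead_params(model, "path/to/reward_model_checkpoint"):
#       print(model.get_buffer("reward_head_weight").shape)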