from typing import Union

import torch.nn as nn

from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel


def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
    """Get the base model of our wrapper classes.

    For Actor, Critic and RewardModel, return ``model.model``,
    it's usually a ``transformers.PreTrainedModel``.

    Args:
        model (Union[Actor, Critic, RewardModel]): wrapper model to get the base model from

    Returns:
        nn.Module: the base model stored in ``model.model``

    Raises:
        AssertionError: if ``model`` is not an Actor, Critic or RewardModel instance.
    """
    # NOTE: an ``assert`` is skipped under ``python -O``; it is kept (rather than
    # raising TypeError) so callers relying on AssertionError keep working.
    assert isinstance(
        model, (Actor, Critic, RewardModel)
    ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
    return model.model


# Names re-exported as this package's public API (also controls ``from ... import *``).
__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]