from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from coati.models.base import Actor, get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .sampler import DistributedSampler

# A model together with the optimizer that updates it.
ModelOptimPair = Tuple[nn.Module, Optimizer]
# Argument/return element of Strategy.prepare: either a bare model
# (inference only) or a (model, optimizer) pair (training).
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]

class Strategy(ABC):
    """Base class for training strategies.

    A concrete strategy encapsulates one distributed-training backend:
    how the process group is initialized, how models/optimizers/dataloaders
    are wrapped for that backend, and how checkpoints are saved and loaded.
    """

    def __init__(self) -> None:
        super().__init__()
        # Bring up the distributed environment as soon as the strategy exists,
        # so every later setup_* call can assume it is initialized.
        self.setup_distributed()

    @abstractmethod
    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
        """Run the backward pass for ``loss`` (backend may scale or accumulate)."""

    @abstractmethod
    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
        """Apply one optimizer update step."""

    @abstractmethod
    def setup_distributed(self) -> None:
        """Initialize the distributed environment (process group, devices)."""

    @abstractmethod
    def setup_model(self, model: nn.Module) -> nn.Module:
        """Wrap ``model`` for this backend and return the wrapped module."""

    @abstractmethod
    def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
        """Wrap ``optimizer`` (bound to ``model``) for this backend."""

    @abstractmethod
    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        """Build a dataloader over ``replay_buffer`` suited to this backend."""

    def model_init_context(self):
        """Context manager active while models are constructed.

        Default is a no-op; backends such as ZeRO override this to shard
        parameters at construction time.
        """
        return nullcontext()

    def prepare(
        self, *models_or_model_optim_pairs: ModelOrModelOptimPair
    ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
        """Prepare models or model-optimizer-pairs based on each strategy.

        Example::
            >>> # when fine-tuning actor and critic
            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
            >>> # or when training reward model
            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
            >>> # or just inference
            >>> actor, critic = strategy.prepare(actor, critic)

        Returns:
            Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
        """

        def prepare_model(model: nn.Module):
            # Actors are re-wrapped so callers keep the Actor interface while
            # the inner base model gets the backend-specific wrapping.
            if isinstance(model, Actor):
                return Actor(self.setup_model(model.get_base_model()))
            return self.setup_model(model)

        rets = []
        for arg in models_or_model_optim_pairs:
            if isinstance(arg, tuple):
                assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                model, optimizer = arg
                model = prepare_model(model)
                # The optimizer must see the unwrapped parameters, hence get_base_model.
                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
                rets.append((model, optimizer))
            elif isinstance(arg, nn.Module):
                rets.append(prepare_model(arg))
            else:
                raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')

        # Single argument: return the prepared object directly, not a 1-list.
        if len(rets) == 1:
            return rets[0]
        return rets

    @staticmethod
    def unwrap_model(model: nn.Module) -> nn.Module:
        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
        For Actor, it will unwrap `actor.model`.

        Args:
            model (nn.Module): the model to unwrap

        Returns:
            nn.Module: the original model (usually a huggingface model)
        """
        return get_base_model(model)

    @abstractmethod
    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        """Save ``model`` weights to ``path`` (optionally only from rank 0)."""

    @abstractmethod
    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
        """Load weights from ``path`` into ``model``."""

    @abstractmethod
    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
        """Save optimizer state to ``path`` (optionally only from rank 0)."""

    @abstractmethod
    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
        """Load optimizer state from ``path``."""

    def setup_sampler(self, dataset) -> DistributedSampler:
        """Return a sampler over ``dataset``.

        Default is a single-process sampler (world size 1, rank 0);
        distributed backends override this.
        """
        return DistributedSampler(dataset, 1, 0)

    @abstractmethod
    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        """Save ``model`` (and optionally ``tokenizer``) in HuggingFace pretrained format."""