from typing import Dict, List, Optional

from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from colossalai.accelerator import get_accelerator

from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy
from .utils import CycledDataLoader, is_rank_0, to_device


def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    unwrapped_model = strategy.unwrap_model(actor)
    hf_model = get_base_model(unwrapped_model)
    new_kwargs = {**generate_kwargs}
    # use the HuggingFace model's generation helpers directly when available
    if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation

    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation

    return new_kwargs


class PPOTrainer(OnPolicyTrainer):
    """
    Trainer for the PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in the PPO algorithm
        critic (Critic): the critic model in the PPO algorithm
        reward_model (RewardModel): the reward model in the RLHF algorithm, used to score generated sequences
        initial_model (Actor): the initial (reference) model in the RLHF algorithm, used to generate reference logits that constrain updates of the actor
        actor_optim (Optimizer): the optimizer to use for the actor model
        critic_optim (Optimizer): the optimizer to use for the critic model
        tokenizer (PreTrainedTokenizerBase): the tokenizer used when generating experiences
        kl_coef (float, defaults to 0.1): the coefficient of the KL-divergence penalty
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the maximum size of the experience buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload the experience buffer to CPU
        eps_clip (float, defaults to 0.2): the clipping coefficient of the policy loss
        vf_coef (float, defaults to 1.0): the coefficient of the value loss
        ptx_coef (float, defaults to 0.9): the coefficient of the pretraining (ptx) loss
        value_clip (float, defaults to 0.4): the clipping coefficient of the value loss
        sample_buffer (bool, defaults to False): whether to sample experiences directly from the buffer during learning
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        offload_inference_models (bool, defaults to True): whether to offload the inference models (reward and initial models) to CPU during the training phase
        callbacks (List[Callback], defaults to []): the callbacks to call during the training process
        generate_kwargs (dict, optional): the keyword arguments passed to the model when generating
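
    Example:
        A minimal construction sketch. Object creation follows the ``__init__`` signature
        below; the ``trainer.fit(...)`` call is only an assumption about the entry point
        inherited from ``OnPolicyTrainer``, and its exact arguments may differ.

            trainer = PPOTrainer(
                strategy, actor, critic, reward_model, initial_model,
                actor_optim, critic_optim, tokenizer,
                kl_coef=0.1, ptx_coef=0.9, train_batch_size=8,
            )
            # assumed inherited entry point; check OnPolicyTrainer.fit for the real signature
            trainer.fit(prompt_dataloader, pretrain_dataloader)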
    """

    def __init__(
        self,
        strategy: Strategy,
        actor: Actor,
        critic: Critic,
        reward_model: RewardModel,
        initial_model: Actor,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        tokenizer: PreTrainedTokenizerBase,
        kl_coef: float = 0.1,
        ptx_coef: float = 0.9,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        vf_coef: float = 1.0,
        value_clip: float = 0.4,
        sample_buffer: bool = False,
        dataloader_pin_memory: bool = True,
        offload_inference_models: bool = True,
        callbacks: List[Callback] = [],
        **generate_kwargs,
    ) -> None:
        if isinstance(strategy, GeminiStrategy):
            assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"

        data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)

        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)

        self.actor = actor
        self.critic = critic
        self.tokenizer = tokenizer

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.vf_coef = vf_coef
        self.ptx_loss_fn = GPTLMLoss()
        self.ptx_coef = ptx_coef
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim

        self.offload_inference_models = offload_inference_models
        self.device = get_accelerator().get_current_device()

    def _before_fit(
        self,
        prompt_dataloader: DataLoader,
        pretrain_dataloader: DataLoader,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            prompt_dataloader (DataLoader): the dataloader to use for prompt data
            pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
        """
        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)

        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            wandb.init(project="Coati-ppo", sync_tensorboard=True)
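            # sync_tensorboard=True mirrors everything written through the SummaryWriter
            # below into the same W&B run, so one writer serves both logging backends.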
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "ppo")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

    def _make_experience(self, collect_step: int) -> Experience:
        prompts = self.prompt_dataloader.next()
        if self.offload_inference_models:
            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
            self.experience_maker.initial_model.to(self.device)
            self.experience_maker.reward_model.to(self.device)
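
        # Rollout: the experience maker generates responses from the actor for this prompt
        # batch and records what PPO needs later: old action log-probs, critic values,
        # reward-model scores, and reference log-probs from the initial model (used for the
        # KL penalty weighted by kl_coef).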
        assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
        return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)

    def _training_step(self, experience: Experience):
        self.actor.train()
        self.critic.train()
        # policy loss
        num_actions = experience.action_log_probs.size(1)
        actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
        action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
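        # PolicyLoss implements the PPO clipped-surrogate objective: the probability ratio
        # exp(new_log_probs - old_log_probs) is weighted by the advantages and clipped to
        # the [1 - eps_clip, 1 + eps_clip] range.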
        actor_loss = self.actor_loss_fn(
            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
        )
        actor_loss = (1 - self.ptx_coef) * actor_loss
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)

        # ptx loss
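        # The pretraining (ptx) auxiliary loss is a plain LM cross-entropy on a batch from the
        # pretrain dataloader, mixed with the policy loss via ptx_coef so the actor stays close
        # to the original language-model distribution (InstructGPT-style PPO-ptx). Note that,
        # despite its name, `ptx_log_probs` below holds raw logits, which GPTLMLoss consumes.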
        if self.ptx_coef != 0:
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
            ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
            self.strategy.backward(ptx_loss, self.actor, self.actor_optim)

        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # value loss
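        # ValueLoss regresses the critic's new value estimates toward the observed reward,
        # clipping the value update around the old values by value_clip; the result is scaled
        # by vf_coef before backprop.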
        values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
        critic_loss = critic_loss * self.vf_coef
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

    def _learn(self, update_step: int):
        if self.offload_inference_models:
            self.experience_maker.initial_model.to("cpu")
            self.experience_maker.reward_model.to("cpu")

        # the buffer may be empty at first, so the dataloader over it has to be rebuilt for each training round
        if self.sample_buffer:
            experience = self.data_buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            self._training_step(experience)
            self._on_learn_batch_end(experience)
        else:
            if isinstance(self.dataloader.sampler, DistributedSampler):
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                self._training_step(experience)
                self._on_learn_batch_end(experience)