ppo.py 10.9 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional, Union
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
2
3
4
5

import torch
import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker
6
from coati.models.base import Actor, Critic, get_base_model
Hongxin Liu's avatar
Hongxin Liu committed
7
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
8
from coati.models.utils import calc_action_log_probs
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
9
from coati.replay_buffer import NaiveReplayBuffer
10
from torch import Tensor
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
11
from torch.optim import Optimizer
12
13
from torch.utils.data import DistributedSampler
from tqdm import tqdm
14
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
15

Hongxin Liu's avatar
Hongxin Liu committed
16
17
from colossalai.utils import get_current_device

Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
18
19
from .base import Trainer
from .callbacks import Callback
20
from .strategies import ColossalAIStrategy, Strategy
Hongxin Liu's avatar
Hongxin Liu committed
21
from .utils import is_rank_0, to_device
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
22
23
24
25
26
27
28
29
30
31
32


class PPOTrainer(Trainer):
    """Trainer for the PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in ppo algorithm
        critic (Critic): the critic model in ppo algorithm
        reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
        initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
        actor_optim (Optimizer): the optimizer to use for actor model
        critic_optim (Optimizer): the optimizer to use for critic model
        kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
        ptx_coef (float, defaults to 0.9): weight of the pretraining (ptx) language-modelling loss; the
            policy loss is weighted by ``1 - ptx_coef``. The ptx term is skipped when this is 0.
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        vf_coef (float, defaults to 1.0): the coefficient of value loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        max_epochs (int, defaults to 1): the number of epochs of training process
        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
        callbacks (List[Callback], optional): the callbacks to call during training process; defaults to no callbacks
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 actor: Actor,
                 critic: Critic,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 actor_optim: Optimizer,
                 critic_optim: Optimizer,
                 kl_coef: float = 0.1,
                 ptx_coef: float = 0.9,
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 eps_clip: float = 0.2,
                 vf_coef: float = 1.0,
                 value_clip: float = 0.4,
                 max_epochs: int = 1,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 offload_inference_models: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs) -> None:
        if isinstance(strategy, ColossalAIStrategy):
            from colossalai.booster.plugin import GeminiPlugin
            assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
                "GeminiPlugin is not compatible with manual model.to('cpu')"

        # Build a fresh list per trainer; a shared ``[]`` default would leak between instances.
        callbacks = [] if callbacks is None else callbacks

        experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
        replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)

        self.experience_maker = experience_maker
        self.replay_buffer = replay_buffer
        self.sample_replay_buffer = sample_replay_buffer
        self.offload_inference_models = offload_inference_models

        self.actor = actor
        self.critic = critic

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.vf_coef = vf_coef
        self.ptx_loss_fn = GPTLMLoss()
        self.ptx_coef = ptx_coef
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim

        self.device = get_current_device()

        # attribute name -> (dataloader, cycling iterator); see ``_next_batch``.
        self._batch_iters: Dict[str, Any] = {}

    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        """Run the experience maker on one prompt batch.

        A tensor is passed positionally; a dict is expanded into keyword arguments.

        Raises:
            ValueError: if ``inputs`` is neither a ``Tensor`` nor a ``dict``.
        """
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    @staticmethod
    def _cycle_batches(dataloader):
        """Yield batches from ``dataloader`` indefinitely, one full pass after another.

        Before each pass a ``DistributedSampler`` gets a new epoch number. Otherwise its
        (seeded) order never changes. If a whole pass yields nothing, the generator stops,
        so an empty dataloader makes ``next`` raise ``StopIteration`` instead of spinning forever.
        """
        epoch = 0
        while True:
            sampler = getattr(dataloader, 'sampler', None)
            if isinstance(sampler, DistributedSampler):
                sampler.set_epoch(epoch)
            produced = False
            for batch in dataloader:
                produced = True
                yield batch
            if not produced:
                return
            epoch += 1

    def _next_batch(self, attr: str) -> Any:
        """Return the next batch of the dataloader stored in ``self.<attr>``.

        ``next(iter(dataloader))`` restarts the loader on every call. With a sequential
        sampler, or a ``DistributedSampler`` without ``set_epoch``, that yields the same
        first batch every step, and it re-spawns loader workers each time. Instead, keep
        one cycling iterator per attribute and rebuild it only when the attribute is
        rebound to a different dataloader object.

        Raises:
            StopIteration: if the dataloader produces no batches at all.
        """
        dataloader = getattr(self, attr)
        cached = self._batch_iters.get(attr)
        if cached is None or cached[0] is not dataloader:
            cached = (dataloader, self._cycle_batches(dataloader))
            self._batch_iters[attr] = cached
        return next(cached[1])

    def _learn(self):
        """Run ``max_epochs`` rounds of ``training_step`` over the current replay buffer.

        With ``sample_replay_buffer`` each round trains on one ``replay_buffer.sample()``.
        Otherwise each round is a full pass over a dataloader built from the buffer.
        """
        if self.sample_replay_buffer:
            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
            for _ in pbar:
                experience = self.replay_buffer.sample()
                experience.to_device(self.device)
                metrics = self.training_step(experience)
                pbar.set_postfix(metrics)
            return

        # The replay buffer is refilled between learning phases, so rebuild the dataloader each time.
        # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
        #  but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
        dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
        for epoch in range(self.max_epochs):
            self._on_learn_epoch_start(epoch)
            if isinstance(dataloader.sampler, DistributedSampler):
                dataloader.sampler.set_epoch(epoch)
            pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                metrics = self.training_step(experience)
                self._on_learn_batch_end(metrics, experience)
                pbar.set_postfix(metrics)
            self._on_learn_epoch_end(epoch)

    def fit(self,
            prompt_dataloader,
            pretrain_dataloader,
            num_episodes: int = 50000,
            max_timesteps: int = 500,
            update_timesteps: int = 5000) -> None:
        """Alternate between collecting experience and learning from it.

        Args:
            prompt_dataloader: source of prompt batches fed to ``_make_experience``.
            pretrain_dataloader: source of ptx batches (``input_ids``, ``attention_mask``,
                ``labels`` keys are read in ``training_step``).
            num_episodes: number of episodes to run.
            max_timesteps: experience-collection steps per episode.
            update_timesteps: learn (then clear the replay buffer) every this many
                global steps, counted across episodes.
        """
        global_step = 0
        self.pretrain_dataloader = pretrain_dataloader
        self.prompt_dataloader = prompt_dataloader
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for _ in tqdm(range(max_timesteps),
                          desc=f'Episode [{episode+1}/{num_episodes}]',
                          disable=not is_rank_0()):
                global_step += 1
                prompts = self._next_batch('prompt_dataloader')
                self._on_make_experience_start()
                if self.offload_inference_models:
                    # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
                    self.experience_maker.initial_model.to(self.device)
                    self.experience_maker.reward_model.to(self.device)
                experience = self._make_experience(prompts)
                self._on_make_experience_end(experience)
                self.replay_buffer.append(experience)
                if global_step % update_timesteps == 0:
                    if self.offload_inference_models:
                        # Free device memory for training; the models are moved back before the next rollout.
                        self.experience_maker.initial_model.to('cpu')
                        self.experience_maker.reward_model.to('cpu')
                    self._learn()
                    self.replay_buffer.clear()
            self._on_episode_end(episode)
        self._on_fit_end()

    def training_step(self, experience: Experience) -> Dict[str, float]:
        """One actor update followed by one critic update on ``experience``.

        Returns:
            ``{'reward': mean reward of the experience}``.
        """
        self.actor.train()
        self.critic.train()
        # policy loss
        num_actions = experience.action_mask.size(1)
        actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
        action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)

        # ptx loss: mix in a language-modelling loss on pretraining data
        if self.ptx_coef != 0:
            batch = self._next_batch('pretrain_dataloader')
            batch = to_device(batch, self.device)
            ptx_logits = self.actor(batch['input_ids'],
                                    attention_mask=batch['attention_mask'])['logits']
            ptx_loss = self.ptx_loss_fn(ptx_logits, batch['labels'])
            actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # value loss
        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)
        critic_loss = critic_loss * self.vf_coef
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

        return {'reward': experience.reward.mean().item()}


def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    """Return a copy of ``generate_kwargs`` with HuggingFace generation hooks filled in.

    The actor is unwrapped by ``strategy`` and reduced to its base model. Each hook that
    the caller did not pass explicitly is then taken from that model, if the model has it.
    The input dict is never modified.
    """
    base_model = get_base_model(strategy.unwrap_model(actor))

    # (kwarg the generator expects, method looked up on the base model)
    hook_sources = (
        ('prepare_inputs_fn', 'prepare_inputs_for_generation'),
        ('update_model_kwargs_fn', '_update_model_kwargs_for_generation'),
    )

    merged = dict(generate_kwargs)
    for kwarg_name, method_name in hook_sources:
        if kwarg_name in generate_kwargs or not hasattr(base_model, method_name):
            continue
        merged[kwarg_name] = getattr(base_model, method_name)
    return merged