"tests/test_infer/test_bloom_infer.py" did not exist on "eedaa3e1ef991d9f9a274d10c046877ba2b10467"
ppo.py 8.91 KB
Newer Older
1
from typing import Dict, List, Optional
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
2
3
4

import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker
5
from coati.models.base import Actor, Critic, get_base_model
Hongxin Liu's avatar
Hongxin Liu committed
6
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
7
from coati.models.utils import calc_action_log_probs
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
8
from coati.replay_buffer import NaiveReplayBuffer
9
from torch import Tensor
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
10
from torch.optim import Optimizer
11
from torch.utils.data import DataLoader, DistributedSampler
12
from tqdm import tqdm
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
13

Hongxin Liu's avatar
Hongxin Liu committed
14
15
from colossalai.utils import get_current_device

16
from .base import OnPolicyTrainer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
17
from .callbacks import Callback
18
from .strategies import ColossalAIStrategy, Strategy
Hongxin Liu's avatar
Hongxin Liu committed
19
from .utils import is_rank_0, to_device
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
20
21


22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    """Return a copy of ``generate_kwargs`` with generation hooks defaulted.

    When the caller did not provide ``prepare_inputs_fn`` or
    ``update_model_kwargs_fn``, they default to the corresponding methods of
    the unwrapped HuggingFace base model, if that model defines them.

    Args:
        strategy (Strategy): strategy used to unwrap the (possibly wrapped) actor
        generate_kwargs (dict): user-supplied generation kwargs; not mutated
        actor (Actor): actor whose underlying HF model supplies the defaults

    Returns:
        Dict: a new kwargs dict, original entries taking precedence
    """
    base_model = get_base_model(strategy.unwrap_model(actor))
    defaults = dict(generate_kwargs)
    # use huggingface models method directly
    hook_specs = (
        ('prepare_inputs_fn', 'prepare_inputs_for_generation'),
        ('update_model_kwargs_fn', '_update_model_kwargs_for_generation'),
    )
    for kwarg_name, method_name in hook_specs:
        if kwarg_name not in generate_kwargs and hasattr(base_model, method_name):
            defaults[kwarg_name] = getattr(base_model, method_name)
    return defaults


class PPOTrainer(OnPolicyTrainer):
    """
        Trainer for PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in ppo algorithm
        critic (Critic): the critic model in ppo algorithm
        reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
        initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
        actor_optim (Optimizer): the optimizer to use for actor model
        critic_optim (Optimizer): the optimizer to use for critic model
        kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
        ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        vf_coef (float, defaults to 1.0): the coefficient of value loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        sample_buffer (bool, defaults to False): whether to sample from buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
        callbacks (List[Callback], defaults to None): the callbacks to call during training process
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 actor: Actor,
                 critic: Critic,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 actor_optim: Optimizer,
                 critic_optim: Optimizer,
                 kl_coef: float = 0.1,
                 ptx_coef: float = 0.9,
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 eps_clip: float = 0.2,
                 vf_coef: float = 1.0,
                 value_clip: float = 0.4,
                 sample_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 offload_inference_models: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs
                 ) -> None:
        # Fix: the previous default `callbacks=[]` was a mutable default
        # argument, shared across all instantiations; default to None instead.
        if callbacks is None:
            callbacks = []
        if isinstance(strategy, ColossalAIStrategy):
            from colossalai.booster.plugin import GeminiPlugin
            assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
                "GeminiPlugin is not compatible with manual model.to('cpu')"

        buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(
            strategy, buffer,
            sample_buffer, dataloader_pin_memory,
            callbacks
        )

        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
        self.offload_inference_models = offload_inference_models

        self.actor = actor
        self.critic = critic

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.vf_coef = vf_coef
        self.ptx_loss_fn = GPTLMLoss()
        self.ptx_coef = ptx_coef
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim

        self.device = get_current_device()

    def _make_experience(self, collect_step: int) -> Experience:
        """Generate one batch of experience from the prompt dataloader.

        Args:
            collect_step (int): index of the current collect step (unused here,
                part of the OnPolicyTrainer hook signature)

        Returns:
            Experience: rollout produced by the experience maker

        Raises:
            ValueError: if the dataloader yields neither a Tensor nor a dict
        """
        prompts = self.prompt_dataloader.next()
        if self.offload_inference_models:
            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
            self.experience_maker.initial_model.to(self.device)
            self.experience_maker.reward_model.to(self.device)
        if isinstance(prompts, Tensor):
            return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
        elif isinstance(prompts, dict):
            return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(prompts)}"')

    def _training_step(self, experience: Experience) -> Dict[str, float]:
        """Run one PPO optimization step on a single experience batch.

        Updates the actor with the clipped policy loss (optionally blended
        with a pretraining LM loss via ptx_coef) and the critic with the
        clipped value loss scaled by vf_coef.

        Args:
            experience (Experience): a rollout batch already on self.device

        Returns:
            Dict[str, float]: training metrics (currently mean reward)
        """
        self.actor.train()
        self.critic.train()
        # policy loss
        num_actions = experience.action_mask.size(1)
        actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
        action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)

        # ptx loss: blend in a pretraining LM loss to limit drift from the base model
        if self.ptx_coef != 0:
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            ptx_log_probs = self.actor(batch['input_ids'],
                                       attention_mask=batch['attention_mask'])['logits']
            ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
            actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # value loss
        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)
        critic_loss = critic_loss * self.vf_coef
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

        return {'reward': experience.reward.mean().item()}

    def _learn(self, update_step: int):
        """Run one learning pass over collected experience.

        Either samples a single batch directly from the replay buffer
        (sample_buffer=True) or iterates the prepared dataloader, calling
        _training_step per batch with the batch-level callbacks around it.

        Args:
            update_step (int): index of the current update step, used to
                reshuffle a DistributedSampler per epoch
        """
        if self.offload_inference_models:
            # free GPU memory: inference-only models are not needed during learning
            self.experience_maker.initial_model.to('cpu')
            self.experience_maker.reward_model.to('cpu')

        # buffer may be empty at first, we should rebuild at each training
        if self.sample_buffer:
            experience = self.buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            metrics = self._training_step(experience)
            self._on_learn_batch_end(metrics, experience)
        else:
            if isinstance(self.dataloader.sampler, DistributedSampler):
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(
                self.dataloader,
                desc=f'Train epoch [{update_step + 1}]',
                disable=not is_rank_0()
            )
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                metrics = self._training_step(experience)
                self._on_learn_batch_end(metrics, experience)
                pbar.set_postfix(metrics)