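"""Ray actor wrapping a detached PPO trainer.

The trainer owns the actor/critic models and their optimizers, receives
experience from remote experience maker holders, runs PPO updates, and pushes
updated weights back to the holders.
"""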
from typing import Dict, List, Optional

import ray
import torch
from torch.optim import Adam

from coati.experience_maker import Experience
from coati.models.base import Actor
from coati.models.generation_utils import update_model_kwargs_fn
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy

from colossalai.nn.optimizer import HybridAdam

from .detached_trainer_base import DetachedTrainer
from .utils import get_cuda_actor_critic_from_args, get_strategy_from_args, is_rank_0, set_dist_env


@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1})
class DetachedPPOTrainer(DetachedTrainer):
    '''
    Detached Trainer for the PPO algorithm.

    Args:
        experience_maker_holder_name_list (List[str]): names of the remote experience maker holders to push model updates to
        strategy (str): the strategy to use for training
        model (str): for actor / critic init
        env_info (Dict[str, str], optional): distributed environment variables to set before training
        pretrained (str): for actor / critic init
        lora_rank (int): for actor / critic init
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of the replay buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload the replay buffer to CPU
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 10): the number of epochs of the training process
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during the training process
        generate_kwargs (dict, optional): the kwargs to use while the model generates
    '''

    def __init__(self,
                 experience_maker_holder_name_list: List[str],
                 strategy: str,
                 model: str,
                 env_info: Optional[Dict[str, str]] = None,
                 pretrained: Optional[str] = None,
                 lora_rank: int = 0,
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 eps_clip: float = 0.2,
                 value_clip: float = 0.4,
                 experience_batch_size: int = 8,
                 max_epochs: int = 10,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        # configure strategy
        self.strategy = get_strategy_from_args(strategy)
        # configure models, loss and optimizers
        with self.strategy.model_init_context():
            self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank)

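        # Cast to fp16 and move to the current CUDA device for every strategy
        # except colossalai_gemini, which manages model placement itself.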
        if strategy != 'colossalai_gemini':
            self.actor.to(torch.float16).to(torch.cuda.current_device())
            self.critic.to(torch.float16).to(torch.cuda.current_device())

        if strategy.startswith('colossalai'):
            self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6)
            self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6)
        else:
            self.actor_optim = Adam(self.actor.parameters(), lr=5e-6)
            self.critic_optim = Adam(self.critic.parameters(), lr=5e-6)

        (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
            self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)

        super().__init__(experience_maker_holder_name_list,
                         train_batch_size=train_batch_size,
                         buffer_limit=buffer_limit,
                         buffer_cpu_offload=buffer_cpu_offload,
                         experience_batch_size=experience_batch_size,
                         max_epochs=max_epochs,
                         dataloader_pin_memory=dataloader_pin_memory,
                         callbacks=callbacks,
                         **generate_kwargs)

    @ray.method(concurrency_group="model_io")
    def _update_remote_makers(self):
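        # Broadcast the current (unwrapped) actor and critic weights to every
        # registered experience maker holder; only rank 0 performs the transfer.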
        # TODO: balance duties
        if is_rank_0():
            self.update_target_holder_list(self.target_holder_name_list)
            for target_holder in self.target_holder_list:
                # TODO: reduce malloc
                with torch.no_grad():
                    ray.get(target_holder.update_experience_maker.remote(
                        self._get_unwrapped_actor(), self._get_unwrapped_critic()))

    @ray.method(concurrency_group="model_io")
    def initialize_remote_makers(self):
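        # Same as _update_remote_makers, but calls the holders' one-time
        # initialization entry point instead of the regular update.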
        # TODO: balance duties
        if is_rank_0():
            self.update_target_holder_list(self.target_holder_name_list)
            for target_holder in self.target_holder_list:
                # TODO: reduce malloc
                with torch.no_grad():
                    ray.get(target_holder.initialize_experience_maker.remote(
                        self._get_unwrapped_actor(), self._get_unwrapped_critic()))

    @ray.method(concurrency_group="compute")
    def training_step(self, experience: Experience) -> Dict[str, float]:
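        # One PPO update on a single experience batch: a clipped policy loss for
        # the actor, then a clipped value loss for the critic, each followed by a
        # strategy-aware backward pass and optimizer step.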
        self.actor.train()
        self.critic.train()

        experience.to_device(torch.cuda.current_device())
        num_actions = experience.action_mask.size(1)
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)

        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()
        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.actor, path, only_rank0)

    def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.critic, path, only_rank0)

    def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.actor_optim, path, only_rank0)

    def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

    def _get_unwrapped_actor(self):
        # Return the actor without strategy-specific wrappers so its weights can
        # be shipped to the experience maker holders.
        if isinstance(self.strategy, ColossalAIStrategy):
            return Actor(self.strategy._unwrap_model(self.actor))
        if isinstance(self.strategy, DDPStrategy):
            return Actor(self.strategy._unwrap_actor(self.actor))
        if isinstance(self.strategy, NaiveStrategy):
            return self.actor
        raise ValueError(f'Unsupported strategy: {type(self.strategy).__name__}')

    def _get_unwrapped_critic(self):
        if isinstance(self.strategy, ColossalAIStrategy):
            return self.strategy._unwrap_model(self.critic)
        if isinstance(self.strategy, DDPStrategy):
            return self.critic.module
        if isinstance(self.strategy, NaiveStrategy):
            return self.critic
        raise ValueError(f'Unsupported strategy: {type(self.strategy).__name__}')


def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> dict:
    origin_model = strategy._unwrap_actor(actor)
    new_kwargs = {**generate_kwargs}
    # use the HuggingFace model's generation helpers directly when available
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation

    if 'update_model_kwargs_fn' not in generate_kwargs:
        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

    return new_kwargs
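
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module). It assumes Ray is
# already initialized, an experience maker holder named 'maker1' exists, and a
# GPU is available; the actor name, model name, and arguments are examples only.
#
#   trainer = DetachedPPOTrainer.options(name='trainer1', num_gpus=1).remote(
#       experience_maker_holder_name_list=['maker1'],
#       strategy='naive',
#       model='gpt2',
#       pretrained='gpt2',
#       train_batch_size=8,
#   )
#   # Push the freshly initialized weights to the makers before training starts;
#   # the training loop itself is driven through the base DetachedTrainer.
#   ray.get(trainer.initialize_remote_makers.remote())
# ---------------------------------------------------------------------------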