from typing import Callable, Dict, List, Optional, Tuple

import ray
import torch
from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam

from colossalai.nn.optimizer import HybridAdam

from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to


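# Ray executes methods within the same concurrency group sequentially (each
# group below has a limit of 1), while methods from different groups may run
# concurrently, so buffer access, weight sync ("model_io") and training steps
# ("compute") do not block one another on this actor.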
@ray.remote(
    concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
)
class DetachedPPOTrainer(DetachedTrainer):
    """
    Detached Trainer for the PPO algorithm.

    Args:
        experience_maker_holder_name_list (List[str]): names of the remote experience maker holders to sync weights to
        strategy_fn (Callable[[], Strategy]): factory that builds the training strategy
        model_fn (Callable[[], Tuple[Actor, Critic]]): factory that builds the actor and critic models
        env_info (Dict[str, str], optional): distributed environment variables for this trainer process
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of the replay buffer
        eps_clip (float, defaults to 0.2): the clip coefficient of the policy loss
        value_clip (float, defaults to 0.4): the clip coefficient of the value loss
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[TrainerCallback], defaults to []): the callbacks to call during the training process
        eval_performance (bool, defaults to False): whether to attach a performance evaluator callback
        debug (bool, defaults to False): whether to print verbose debug information
        update_lora_weights (bool, defaults to False): whether to sync only LoRA weights to the remote makers
    """

    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        model_fn: Callable[[], Tuple[Actor, Critic]],
        env_info: Optional[Dict[str, str]] = None,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        eps_clip: float = 0.2,
        value_clip: float = 0.4,
        dataloader_pin_memory: bool = True,
        callbacks: List[TrainerCallback] = [],
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
    ) -> None:
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        # configure strategy
        self.strategy = strategy_fn()
        # configure models, loss and optimizers
        with self.strategy.model_init_context():
            self.actor, self.critic = model_fn()

        if eval_performance:
            actor_numel = get_model_numel(self.actor)
            critic_numel = get_model_numel(self.critic)
            evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
            callbacks = callbacks + [evaluator]

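        # HybridAdam supports the CPU/GPU hybrid parameter updates used by the
        # offloading strategies; plain torch.optim.Adam suffices otherwise.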
        if isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)):
            self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
        else:
            self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)

        (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
            (self.actor, self.actor_optim), (self.critic, self.critic_optim)
        )

        # configure trainer
        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)

        super().__init__(
            experience_maker_holder_name_list,
            train_batch_size=train_batch_size,
            buffer_limit=buffer_limit,
            dataloader_pin_memory=dataloader_pin_memory,
            callbacks=callbacks,
            debug=debug,
        )
        if self._debug:
            print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")

        self._update_lora_weights = update_lora_weights

    @ray.method(concurrency_group="model_io")
    @torch.no_grad()
    def _update_remote_makers(self, fully_update: bool = False, **config):
        # TODO: balance duties
        if not fully_update:
            config["requires_grad_only"] = True
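        # Sync protocol: each maker first receives chunk_start, then a stream
        # of actor and critic state-dict shards, and finally chunk_end so it
        # knows the transfer is complete.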
        self.update_target_holder_list()
        # mark start, ensure order
        tasks = []
        for target_holder in self.target_holder_list:
            tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
        ray.get(tasks)
        # sending loop: stream actor shards first, then critic shards
        tasks = []

        for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(
                        new_actor_state_dict=state_dict_shard,
                        new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
                        fully_update=fully_update,
                    )
                )
        # continue with the critic shards
        for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(
                        new_critic_state_dict=state_dict_shard,
                        new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
                        fully_update=fully_update,
                    )
                )
        ray.get(tasks)
        # mark end
        for target_holder in self.target_holder_list:
            target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

    @ray.method(concurrency_group="compute")
    def training_step(self, experience: Experience) -> Dict[str, float]:
        self.actor.train()
        self.critic.train()

        num_actions = experience.action_mask.size(1)
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
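        # policy update: PPO clipped surrogate objective over the action tokens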
        actor_loss = self.actor_loss_fn(
            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
        )
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

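        # value update: clipped value loss against rewards from the experience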
        values = self.critic(
            experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
        )
        critic_loss = self.critic_loss_fn(
            values, experience.values, experience.reward, action_mask=experience.action_mask
        )

        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()
        return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
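
    # Checkpoint helpers: thin wrappers over the strategy's save routines so a
    # driver process can persist actor/critic weights and optimizer states.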

    def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.actor, path, only_rank0)

    def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.critic, path, only_rank0)

    def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.actor_optim, path, only_rank0)

    def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

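    # When update_lora_weights is set, only the LoRA parameters are streamed to
    # the makers unless a full update is requested, keeping sync traffic small.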
    def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config):
        for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
            if not self._update_lora_weights or fully_update:
                yield state_dict_to(state_dict)
            else:
                state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
                yield state_dict_to(state_dict_lora)

    def _get_model_lora_config_dict(self, model: torch.nn.Module):
        if not self._update_lora_weights:
            return None
        unwrapped_model = self.strategy.unwrap_model(model)
        return LoRAConstructor.extract_lora_config(unwrapped_model)
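
# Example usage (sketch): launching the trainer as a named Ray actor. The maker
# name, strategy and model factories below are illustrative, not defined here.
#
#   import ray
#   from coati.models.gpt import GPTActor, GPTCritic
#   from coati.trainer.strategies import GeminiStrategy
#
#   ray.init()
#   trainer = DetachedPPOTrainer.options(name="trainer1").remote(
#       experience_maker_holder_name_list=["maker1"],
#       strategy_fn=lambda: GeminiStrategy(placement_policy="cuda"),
#       model_fn=lambda: (GPTActor(), GPTCritic()),
#       train_batch_size=8,
#   )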