experience_maker_holder.py 7.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
from typing import Any, Callable, Dict, List, Optional, Union
import ray
from ray.exceptions import GetTimeoutError
from torch import Tensor
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies.sampler import DistributedSampler
from coati.trainer.strategies import Strategy
from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker

from copy import deepcopy
from threading import Lock
import time
import os


from .utils import is_rank_0, get_strategy_from_args, set_dist_env


@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
    '''
    Ray actor that generates experience with local copies of the actor/critic models
    and sends each experience batch to one of the detached trainers (round-robin).

    Args:
        detached_trainer_name_list: names of the detached trainer actors; each name is resolved
            to a ray actor handle in the namespace given by the ``RAY_NAMESPACE`` env variable
        strategy: strategy name, turned into a strategy object by ``get_strategy_from_args``
        env_info: optional distributed-environment variables applied via ``set_dist_env``
        experience_batch_size: number of prompts sampled per generated experience batch
        kl_coef: the coefficient of kl divergence loss
        generate_kwargs: extra keyword arguments forwarded to ``make_experience``; an entry
            ``debug=True`` additionally enables progress prints
    '''

    def __init__(self,
                 detached_trainer_name_list: List[str],
                 strategy: str,
                 env_info: Dict[str, str] = None,
                 experience_batch_size: int = 8,
                 kl_coef: float = 0.1,
                 **generate_kwargs):
        # set environment variables before anything else touches the distributed backend
        if env_info:
            set_dist_env(env_info=env_info)
        self.target_trainer_list = self._lookup_trainers(detached_trainer_name_list)
        # round-robin cursor into target_trainer_list, advanced by _send_experience
        self._target_idx = 0
        self.strategy_str = strategy
        self.strategy = get_strategy_from_args(strategy)
        self.experience_batch_size = experience_batch_size
        self.kl_coef = kl_coef
        # NOTE(review): the 'debug' key (if any) is forwarded to make_experience together with the
        # real generation kwargs — confirm the generation path tolerates the extra key.
        self.generate_kwargs = generate_kwargs
        # same truthiness rule as the original inline checks: only a value == True enables it
        self._debug = self.generate_kwargs.get('debug', False) == True
        # Need a trainer to give an actor and a critic via initialize_experience_maker(...)
        actor, critic, reward_model, initial_model = None, None, None, None
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
        # guards the models inside experience_maker: generation (workingloop) vs. weight swap
        # (update_experience_maker) run in different concurrency groups
        self._model_visit_lock = Lock()
        self.fully_initialized = False
        if self._debug:
            print('[maker] Waiting for INIT')

    def _lookup_trainers(self, detached_trainer_name_list):
        '''Resolve trainer actor names to handles in the ``RAY_NAMESPACE`` namespace.'''
        namespace = os.environ["RAY_NAMESPACE"]
        return [ray.get_actor(name, namespace=namespace) for name in detached_trainer_name_list]

    def _get_ready(self):
        '''Block (polling once per second) until initialize_experience_maker has completed.'''
        while not self.fully_initialized:
            time.sleep(1.0)

    def update_target_trainer_list(self, detached_trainer_name_list):
        '''
        Replace the set of trainers that receive experience.

        Names are resolved in the same namespace as in ``__init__``, and the round-robin
        cursor is reset so it can never point past the end of a shorter list.
        '''
        self.target_trainer_list = self._lookup_trainers(detached_trainer_name_list)
        self._target_idx = 0

    # copy from ../trainer/base.py
    @ray.method(concurrency_group="compute")
    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        '''
        Generate one experience batch from ``inputs`` (a tensor or a dict of tensors).

        Raises:
            ValueError: if ``inputs`` is neither a Tensor nor a dict.
        '''
        self._get_ready()
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    @ray.method(concurrency_group="experience_io")
    def _send_experience(self, experience):
        '''
        Append ``experience`` to the replay buffer of the next trainer in round-robin order.

        The append is fire-and-forget: the ObjectRef returned by ``buffer_append.remote``
        is not waited on. (A "least-loaded trainer" policy based on ``buffer_get_length``
        was tried earlier and is intentionally not used.)

        Raises:
            RuntimeError: if there is no target trainer.
        '''
        # snapshot the list so a concurrent update_target_trainer_list cannot change its
        # length between indexing and advancing the cursor
        trainers = self.target_trainer_list
        if not trainers:
            raise RuntimeError('[maker] no target trainer to send experience to')
        idx = self._target_idx % len(trainers)
        chosen_trainer = trainers[idx]
        if self._debug:
            print(f"[maker] sending exp to {chosen_trainer}")
        chosen_trainer.buffer_append.remote(experience)
        self._target_idx = (idx + 1) % len(trainers)

    def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
        '''
        Repeatedly sample prompts, generate experience and send it to a trainer.

        Args:
            dataset: prompt dataset handed to ``strategy.setup_sampler``
            tokenizer: optional callable turning sampled prompts into model inputs;
                when None the sampled prompts are used as-is
            times: number of experience batches to produce
        '''
        self._get_ready()
        sampler = self.strategy.setup_sampler(dataset)
        for _ in range(times):
            rand_prompts = sampler.sample(self.experience_batch_size)
            inputs = tokenizer(rand_prompts) if tokenizer is not None else rand_prompts
            # hold the lock only while the models are used, so update_experience_maker cannot
            # swap them mid-generation; `with` releases it even if generation raises
            with self._model_visit_lock:
                experience = self._make_experience(inputs=inputs)
            self._send_experience(experience=experience)

    @ray.method(concurrency_group="model_io")
    def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
        '''
        Called by a trainer, once: install the actor and critic and derive the frozen
        initial model (copy of the actor) and reward model (copy of the critic's
        backbone and value head). Later calls are no-ops.
        '''
        # TODO: reduce malloc
        if self.fully_initialized:
            return
        if self._debug:
            print('[maker] INIT')
        with torch.no_grad():
            with self.strategy.model_init_context():
                actor = init_actor
                critic = init_critic
                initial_model = deepcopy(actor)
                reward_model = RewardModel(deepcopy(critic.model),
                                           deepcopy(critic.value_head)).to(torch.cuda.current_device())
            # gemini manages parameter placement itself; other strategies get fp16 on the current GPU
            if self.strategy_str != 'colossalai_gemini':
                actor.to(torch.float16).to(torch.cuda.current_device())
                critic.to(torch.float16).to(torch.cuda.current_device())
                initial_model.to(torch.float16).to(torch.cuda.current_device())
                reward_model.to(torch.float16).to(torch.cuda.current_device())

            self.experience_maker.actor = self.strategy.prepare(actor)
            self.experience_maker.critic = self.strategy.prepare(critic)
            self.experience_maker.initial_model = self.strategy.prepare(initial_model)
            self.experience_maker.reward_model = self.strategy.prepare(reward_model)
        # set last so _get_ready only unblocks once every model is in place
        self.fully_initialized = True

    @ray.method(concurrency_group="model_io")
    def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
        '''
        Called by a trainer: replace the actor and critic used for generation.

        The swap happens under the model lock so it never interleaves with an in-flight
        ``_make_experience`` from ``workingloop``.
        '''
        # TODO: reduce malloc
        with self._model_visit_lock:
            with torch.no_grad():
                if self._debug:
                    print("[maker] UPDATE ")
                if self.strategy_str != 'colossalai_gemini':
                    new_actor.to(torch.float16).to(torch.cuda.current_device())
                    new_critic.to(torch.float16).to(torch.cuda.current_device())
                self.experience_maker.actor = self.strategy.prepare(new_actor)
                self.experience_maker.critic = self.strategy.prepare(new_critic)