detached_trainer_base.py
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import ray
from tqdm import tqdm

from coati.experience_maker import Experience
from coati.trainer.callbacks import Callback

from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0

class DetachedTrainer(ABC):
    '''
    Base class for detached RLHF trainers.

    'Detached' means that the experience maker runs as a separate Ray actor,
    unlike a normal Trainer, which owns its experience maker. Set the actor's
    name when launching the trainer:

        >>> trainer = DetachedTrainer.options(..., name="xxx", ...).remote()

    so that an ExperienceMakerHolder can reach the detached_replay_buffer via
    the actor's name.

    Args:
        experience_maker_holder_name_list (List[str]): names of the ExperienceMakerHolder actors to synchronize with
        train_batch_size (int, defaults to 8): batch size sampled from the replay buffer per training step
        buffer_limit (int, defaults to 0): capacity limit of the detached replay buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload buffered experiences to CPU memory
        experience_batch_size (int, defaults to 8): batch size used for experience generation
        max_epochs (int, defaults to 1): number of epochs per learning phase
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[Callback], optional): callbacks to invoke during the training process
        generate_kwargs (dict, optional): kwargs used while the model generates; set 'debug' to True for verbose logging
    '''
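
    # Data flow: ExperienceMakerHolder actors look this trainer up by its Ray
    # actor name and push experiences into the detached replay buffer through
    # buffer_append; _learn drains the buffer through _buffer_sample and runs
    # training_step on each sampled batch.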

    def __init__(self,
                 experience_maker_holder_name_list: List[str],
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 dataloader_pin_memory: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs) -> None:
        super().__init__()
        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size,
                                                           limit=buffer_limit,
                                                           cpu_offload=buffer_cpu_offload)
        self.experience_batch_size = experience_batch_size
        self.max_epochs = max_epochs
        self.dataloader_pin_memory = dataloader_pin_memory
        # Avoid a mutable default argument: fall back to an empty list here.
        self.callbacks = callbacks if callbacks is not None else []
        self.generate_kwargs = generate_kwargs
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []

    def update_target_holder_list(self, experience_maker_holder_name_list: List[str]) -> None:
        # Refresh the cached actor handles by looking up each holder name in
        # the shared Ray namespace.
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []
        for name in self.target_holder_name_list:
            self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

    @abstractmethod
    def _update_remote_makers(self):
        # Subclasses should propagate the freshly trained model to every
        # ExperienceMakerHolder in self.target_holder_list.
        pass

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        # Subclasses should run one optimization step on a batch of experience
        # and return a dict of metrics for logging.
        pass

    def _learn(self):
        pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
        for _ in pbar:
            if self.generate_kwargs.get('debug', False):
                print("[trainer] sampling exp")
            experience = self._buffer_sample()
            if self.generate_kwargs.get('debug', False):
                print("[trainer] training step")
            metrics = self.training_step(experience)
            if self.generate_kwargs.get('debug', False):
                print("[trainer] step over")
            pbar.set_postfix(metrics)

    def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            # Run max_timesteps // update_timesteps learn-then-sync cycles per
            # episode: drain the replay buffer, then push the updated model to
            # the remote experience makers.
            for _ in tqdm(range(max_timesteps // update_timesteps),
                          desc=f'Episode [{episode + 1}/{num_episodes}]',
                          disable=not is_rank_0()):
                self._learn()
                self._update_remote_makers()
            self._on_episode_end(episode)
        self._on_fit_end()

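    # The three buffer methods below are placed in dedicated Ray concurrency
    # groups so remote ExperienceMakerHolder calls can append to (and query)
    # the replay buffer concurrently with local sampling. Note: for these
    # annotations to take effect, the concrete actor class must be declared
    # with matching groups, e.g. ray.remote(concurrency_groups={
    # "buffer_length": 1, "buffer_append": 1, "buffer_sample": 1}); the group
    # sizes here are illustrative.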
    @ray.method(concurrency_group="buffer_length")
    def buffer_get_length(self):
        # called by ExperienceMakerHolder
        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
            print("[trainer]                telling length")
        return self.detached_replay_buffer.get_length()

    @ray.method(concurrency_group="buffer_append")
    def buffer_append(self, experience: Experience):
        # called by ExperienceMakerHolder
        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
            # print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}")
            print(f"[trainer]               receiving exp.")
        self.detached_replay_buffer.append(experience)

    @ray.method(concurrency_group="buffer_sample")
    def _buffer_sample(self):
        return self.detached_replay_buffer.sample()

    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)
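
# Usage sketch (illustrative; `DetachedPPOTrainer`, the actor name, and the
# concurrency-group sizes below are assumptions, not part of this module):
#
#   TrainerCls = ray.remote(concurrency_groups={"buffer_length": 1,
#                                               "buffer_append": 1,
#                                               "buffer_sample": 1})(DetachedPPOTrainer)
#   trainer = TrainerCls.options(name="trainer1",
#                                namespace=os.environ["RAY_NAMESPACE"]).remote(
#       experience_maker_holder_name_list=["maker1"])
#   trainer.fit.remote(num_episodes=50000, max_timesteps=500, update_timesteps=5000)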