detached_replay_buffer.py
import copy
from threading import Lock
from typing import List

import torch
from ray.util.queue import Queue

from coati.experience_maker.base import Experience
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch

class DetachedReplayBuffer:
    '''
    Detached replay buffer that shares Experience objects across workers on the same node,
    so each trainer node is expected to hold exactly one instance.
    The ExperienceMakerHolder is responsible for calling append(exp) remotely.

    Args:
        sample_batch_size: Batch size used when sampling. Experiences are not enqueued
            until a full batch has been collected.
        tp_world_size: Number of workers in the same tensor-parallel group. Defaults to 1.
        limit: Maximum number of experience batches kept in the queue. A value <= 0 means
            unlimited. Defaults to 0.
        cpu_offload: Whether to offload appended experiences to CPU. Defaults to True.
    '''

    def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit: int = 0, cpu_offload: bool = True) -> None:
        self.cpu_offload = cpu_offload
        self.sample_batch_size = sample_batch_size
        self.limit = limit
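        # The queue is backed by a Ray actor, so it can be shared by processes on the same node.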
        self.items = Queue(self.limit, actor_options={"num_cpus":1})
        self.batch_collector: List[BufferItem] = []

        '''
        Workers in the same tp group share this buffer and need the same sample for one step.
        Therefore a held_sample should be returned tp_world_size times before it can be dropped.
        worker_state records whether each worker has already received the held_sample.
        '''
        self.tp_world_size = tp_world_size
        self.worker_state = [False] * self.tp_world_size
        self.held_sample = None
        self._worker_state_lock = Lock()

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        '''
        Split the incoming experience batch and enqueue full sample batches.
        Expected to be called remotely by the ExperienceMakerHolder.
        '''
        if self.cpu_offload:
            experience.to_device(torch.device('cpu'))
        items = split_experience_batch(experience)
        self.batch_collector.extend(items)
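        # Flush every full batch into the shared queue; leftover items wait for the next append.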
        while len(self.batch_collector) >= self.sample_batch_size:
            batch_items = self.batch_collector[:self.sample_batch_size]
            batch = make_experience_batch(batch_items)
            self.items.put(batch, block=True)
            self.batch_collector = self.batch_collector[self.sample_batch_size:]

    def clear(self) -> None:
        # Drop any queued batches and reset the tp synchronization state.
        self.items.shutdown()
        # Recreate the queue with the same actor options as in __init__.
        self.items = Queue(self.limit, actor_options={"num_cpus": 1})
        self.worker_state = [False] * self.tp_world_size
        self.batch_collector = []
     
    @torch.no_grad()
    def sample(self, worker_rank: int = 0, to_device: str = "cpu") -> Experience:
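        # Workers in the same tp group must all receive the same batch for a given step:
        # the first rank to arrive pops a fresh batch and holds it, later ranks get a copy,
        # and once every rank has been served the state resets for the next round.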
        with self._worker_state_lock:
            if not any(self.worker_state):
                self.held_sample = self._sample_and_erase()
            self.worker_state[worker_rank] = True
            if all(self.worker_state):
                # Every tp rank has been served; hand out the held sample itself and reset.
                self.worker_state = [False] * self.tp_world_size
                ret = self.held_sample
            else:
                ret = copy.deepcopy(self.held_sample)
        ret.to_device(to_device)
        return ret

    @torch.no_grad()
    def _sample_and_erase(self) -> Experience:
        # Pop the next batch from the shared queue, blocking until one is available.
        return self.items.get(block=True)

    def get_length(self) -> int:
        return self.items.qsize()
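
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes `ray` has been initialised and that `some_experience` is an
# `Experience` produced elsewhere, normally by an ExperienceMakerHolder that
# calls `append` remotely on the trainer node's buffer.
#
#   buffer = DetachedReplayBuffer(sample_batch_size=4, tp_world_size=2)
#   buffer.append(some_experience)             # enqueues once a full batch is collected
#   exp_rank0 = buffer.sample(worker_rank=0)   # both tp ranks receive the same batch
#   exp_rank1 = buffer.sample(worker_rank=1)
# ---------------------------------------------------------------------------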