utils.py 2.36 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn.functional as F
from coati.experience_maker.base import Experience


@dataclass
class BufferItem:
    """BufferItem is an item of experience data.

    Shapes of each tensor:
    sequences: (S)
    action_log_probs: (A)
    values: (1)
    reward: (1)
18
    advantages: (1)
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
19
20
21
22
23
    attention_mask: (S)
    action_mask: (A)

    "A" is the number of actions.
    """
24

Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
25
26
27
28
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
29
    kl: torch.Tensor
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
30
31
32
33
34
35
36
37
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]


def split_experience_batch(experience: Experience) -> List[BufferItem]:
    """Split a batched Experience into one BufferItem per sample.

    Every tensor attribute is unbound along dim 0, so row ``i`` of each tensor
    goes into the ``i``-th item. Non-tensor attributes (e.g. a ``None`` mask)
    are copied unchanged into every item.

    Args:
        experience: batch whose ``sequences.size(0)`` is taken as the batch size.

    Returns:
        A list of ``BufferItem`` whose length equals the batch size.
    """
    n_samples = experience.sequences.size(0)
    field_names = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask")
    per_sample = [dict() for _ in range(n_samples)]
    for name in field_names:
        attr = getattr(experience, name)
        # Tensors are sliced row-wise; anything else (None) is shared by all items.
        pieces = torch.unbind(attr) if isinstance(attr, torch.Tensor) else [attr] * n_samples
        assert n_samples == len(pieces)
        for slot, piece in zip(per_sample, pieces):
            slot[name] = piece
    return [BufferItem(**slot) for slot in per_sample]


53
54
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
    assert side in ("left", "right")
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
55
56
57
58
    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
59
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
60
61
62
63
64
65
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)


def make_experience_batch(items: List[BufferItem]) -> Experience:
    """Collate per-sample BufferItems back into a batched Experience.

    ``action_log_probs`` and ``action_mask`` may differ in length per item, so
    they are zero-padded on the left (the ``_zero_pad_sequences`` default) to a
    common length. Every other field is stacked along a new dim 0 and must
    already share a shape across items. A field that is ``None`` in every item
    (as ``split_experience_batch`` produces for an absent Optional mask) stays
    ``None`` in the batch.

    Args:
        items: non-empty list of per-sample items.

    Returns:
        An ``Experience`` built from the collated fields.

    Raises:
        ValueError: if ``items`` is empty.
    """
    if not items:
        raise ValueError("items must be non-empty")
    kwargs = {}
    to_pad_keys = {"action_log_probs", "action_mask"}
    keys = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask")
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if all(v is None for v in vals):
            # Absent Optional field: torch.stack / F.pad cannot take None, and
            # keeping it None lets split -> make round-trip.
            batch_data = None
        elif key in to_pad_keys:
            batch_data = _zero_pad_sequences(vals)
        else:
            batch_data = torch.stack(vals, dim=0)
        kwargs[key] = batch_data
    return Experience(**kwargs)