from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch

from coati.models import Critic, RewardModel
from transformers import PreTrainedModel

@dataclass
class Experience:
    """Experience is a batch of data collected during rollout for RL training.

    These data should have the sequence length and number of actions.
    Left padding for sequences is applied.

    Shapes of each tensor (B = batch size, S = sequence length,
    A = number of actions):
    sequences: (B, S)
    action_log_probs: (B, A)
    values: (B)
    reward: (B)
    kl: (B)
    advantages: (B)
    attention_mask: (B, S)
    action_mask: (B, A)
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    kl: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

    @torch.no_grad()
    def to_device(self, device: torch.device) -> None:
        """Move every tensor of this experience to ``device`` in place.

        Args:
            device (torch.device): Target device.
        """
        self.sequences = self.sequences.to(device)
        self.action_log_probs = self.action_log_probs.to(device)
        self.values = self.values.to(device)
        self.reward = self.reward.to(device)
        self.advantages = self.advantages.to(device)
        self.kl = self.kl.to(device)
        # Masks are optional; only move them when present.
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.to(device)
        if self.action_mask is not None:
            self.action_mask = self.action_mask.to(device)

    def pin_memory(self):
        """Pin every tensor in page-locked (pinned) host memory.

        Pinned memory enables faster, asynchronous host-to-GPU transfers.

        Returns:
            Experience: ``self``, to allow call chaining.
        """
        self.sequences = self.sequences.pin_memory()
        self.action_log_probs = self.action_log_probs.pin_memory()
        self.values = self.values.pin_memory()
        self.reward = self.reward.pin_memory()
        self.advantages = self.advantages.pin_memory()
        self.kl = self.kl.pin_memory()
        # Masks are optional; only pin them when present.
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.pin_memory()
        if self.action_mask is not None:
            self.action_mask = self.action_mask.pin_memory()
        return self


class ExperienceMaker(ABC):
    """
    Base class for experience makers.

    Holds the four models used during rollout: the actor being trained, its
    critic, the reward model, and a frozen copy of the initial actor
    (used as the KL reference policy — presumably; confirm in subclasses).
    """

    def __init__(
        self, actor: PreTrainedModel, critic: Critic, reward_model: RewardModel, initial_model: PreTrainedModel
    ) -> None:
        """
        Args:
            actor (PreTrainedModel): The policy model being optimized.
            critic (Critic): The value model that estimates state values.
            reward_model (RewardModel): The model that scores generated sequences.
            initial_model (PreTrainedModel): A frozen reference copy of the actor.
        """
        super().__init__()
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.initial_model = initial_model

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
        """
        Abstract method to generate an experience.

        Args:
            input_ids (torch.Tensor): The input tensor.
            attention_mask (torch.Tensor): The attention mask tensor.
            **generate_kwargs: Additional keyword arguments for generating the experience.

        Returns:
            Experience: The generated experience.
        """