test_models.py 8.43 KB
Newer Older
1
2
3
4
5
6
7
8
import copy
from typing import Any, Callable, Dict, Tuple

import pytest
import torch
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
9
10
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
11
12
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
13
from coati.models.llama import LlamaActor
14
15
16
17
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
18

19
20
21

@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@pytest.mark.parametrize(
    "actor_maker",
    [
        lambda: BLOOMActor(),
        lambda: GPTActor(),
        # HACK: skip llama due to long execution time
        # lambda: LlamaActor(),
        lambda: OPTActor(),
        # lambda: ChatGLMActor(),
    ],
)
@pytest.mark.parametrize(
    "generate_kwargs",
    [
        {
            "max_length": 64,
            "use_cache": True,
            "do_sample": True,
            "temperature": 1.0,
            "top_k": 50,
        }
    ],
)
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
    """Check that ``generate`` returns sequences of exactly ``max_length`` tokens.

    A fresh actor is built from ``actor_maker`` and given ``batch_size`` random
    prompts of ``seq_len`` token ids drawn from ``[0, 100)``. The returned
    sequences must have shape ``(batch_size, generate_kwargs["max_length"])``.

    The actor and prompts are placed on the GPU, so the test is skipped when no
    CUDA device is available.
    """
    if not torch.cuda.is_available():
        # Skip before building the model: everything below calls .cuda(), which
        # would otherwise fail with an error rather than report a skip.
        pytest.skip("test_generation requires a CUDA device")
    actor = actor_maker()
    input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
    sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
    assert sequences.shape == (batch_size, generate_kwargs["max_length"])


def test_utils():
    """Smoke-test the tensor helpers imported from ``coati.models.utils``.

    Checks three things:
    * ``masked_mean`` over a tensor of ones yields the 0-dim value 1.0.
    * ``compute_reward`` yields one value per batch row.
    * ``calc_action_log_probs`` yields a ``(batch_size, num_actions)`` tensor.
    """
    mask = torch.randint(0, 2, (10,))
    # A random 0/1 mask of length 10 is all zeros with probability 2**-10. The
    # mean over an empty selection is undefined, so it cannot be asserted to be
    # 1.0. Keep at least one element so the assertion below is deterministic.
    mask[0] = 1
    fn_input = {"tensor": torch.ones((10,)), "mask": mask}
    fn_output = masked_mean(dim=0, **fn_input)
    assert fn_output.dim() == 0
    assert torch.allclose(fn_output, torch.tensor(1.0))

    batch_size = 4
    num_labels = 10
    fn_input = {
        "r": torch.ones((batch_size,)),
        "kl_coef": 1.0,
        "log_probs": torch.randn((batch_size, num_labels)),
        "log_probs_base": torch.randn((batch_size, num_labels)),
        "action_mask": torch.randint(0, 2, (batch_size, num_labels)),
    }
    fn_output = compute_reward(**fn_input)
    assert fn_output.shape == (batch_size,)

    batch_size = 4
    seq_len = 32
    num_labels = 10
    num_actions = 2
    fn_input = {
        # Mimics a model output object that exposes its logits under "logits".
        "output": {"logits": torch.randn((batch_size, seq_len, num_labels))},
        "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
        "num_actions": num_actions,
    }
    fn_output = calc_action_log_probs(**fn_input)
    assert fn_output.shape == (batch_size, num_actions)


@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
def test_lora(lora_rank: int, num_dim: int, num_layers: int):
    """Check LoRA conversion and that one optimizer step only moves the LoRA factors.

    1. ``convert_to_lora_module`` must turn a ``ModuleList`` of ``nn.Linear``
       layers into ``LoraLinear`` layers with ``lora_A`` of shape
       ``(lora_rank, num_dim)`` and ``lora_B`` of shape ``(num_dim, lora_rank)``.
    2. A deep copy taken before training must match the model exactly.
    3. After one Adam step on ``sum(output)``, ``weight`` and ``bias`` must be
       unchanged while the product ``lora_B @ lora_A`` must have changed.
    """
    model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
    lora_model = convert_to_lora_module(model, lora_rank)
    assert isinstance(lora_model, nn.ModuleList)
    assert len(lora_model) == num_layers
    for layer in lora_model:
        assert isinstance(layer, LoraLinear)
        assert layer.lora_A.shape == (lora_rank, num_dim)
        assert layer.lora_B.shape == (num_dim, lora_rank)

    # Snapshot of the untrained model, compared against after the update.
    old_model = copy.deepcopy(lora_model)
    for old_layer, new_layer in zip(old_model, lora_model):
        assert isinstance(new_layer, LoraLinear)
        assert torch.allclose(old_layer.weight, new_layer.weight)
        assert torch.allclose(old_layer.bias, new_layer.bias)
        assert torch.allclose(old_layer.lora_B @ old_layer.lora_A, new_layer.lora_B @ new_layer.lora_A)

    optimizer = torch.optim.Adam(lora_model.parameters())
    x = torch.randn(8, num_dim)
    for layer in lora_model:
        x = layer(x)
    loss = x.sum()
    loss.backward()
    optimizer.step()

    for old_layer, new_layer in zip(old_model, lora_model):
        assert isinstance(new_layer, LoraLinear)
        # The base weight and bias must be left untouched by the update...
        assert torch.allclose(old_layer.weight, new_layer.weight)
        assert torch.allclose(old_layer.bias, new_layer.bias)
        # ...while the low-rank delta must have been trained.
        assert not torch.allclose(
            old_layer.lora_B @ old_layer.lora_A, new_layer.lora_B @ new_layer.lora_A
        )


@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
    "models_maker",
    [
        lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
        lambda: (GPTActor(), GPTCritic(), GPTRM()),
        # HACK: skip llama due to long execution time
        # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
        lambda: (OPTActor(), OPTCritic(), OPTRM()),
        lambda: (ChatGLMActor(), None, None),
    ],
)
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
    """Run forward passes of an actor / critic / reward-model triple and check output shapes.

    ``models_maker`` returns ``(actor, critic, rm)``. ``critic`` and ``rm`` may
    be ``None``; the ChatGLM maker supplies only an actor. For each model that
    is present:
    * the actor's logits must start with ``(batch_size, seq_len)``;
    * the critic and reward model must return one value per sequence,
      i.e. shape ``(batch_size,)``.

    The ChatGLM case loads the ``THUDM/chatglm-6b`` tokenizer through
    ``from_pretrained`` and so presumably needs that checkpoint to be reachable.
    """
    actor_input = {
        "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
    }
    critic_input = {
        "sequences": torch.randint(0, 100, (batch_size, seq_len)),
        "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
    }
    rm_input = {
        "sequences": torch.randint(0, 100, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
    }

    actor, critic, rm = models_maker()
    if isinstance(actor, ChatGLMActor):
        actor = actor.float()
        tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
        # The [gmask, bos] pair is spliced into the middle of the prompt. The
        # pieces sum to seq_len // 2 + 2 + (seq_len // 2 - 2) == seq_len for even
        # seq_len. This actor is given a 4-D (batch, 1, seq, seq) attention mask.
        actor_input = {
            "input_ids": torch.cat(
                (
                    torch.randint(0, 100, (batch_size, seq_len // 2)),
                    chatglm_special_token,
                    torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
                ),
                dim=1,
            ),
            "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
        }
    assert isinstance(actor, Actor)
    get_base_model(actor)
    actor_output = actor(**actor_input)
    assert actor_output.logits.shape[:2] == (batch_size, seq_len)

    # Test against None explicitly: these are nn.Module objects, and a module
    # that defines __len__ (as container modules do) could be falsy.
    if critic is not None:
        assert isinstance(critic, Critic)
        get_base_model(critic)
        critic_output = critic(**critic_input)
        assert critic_output.shape == (batch_size,)

    if rm is not None:
        assert isinstance(rm, RewardModel)
        get_base_model(rm)
        rm_output = rm(**rm_input)
        assert rm_output.shape == (batch_size,)


@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
def test_loss(batch_size: int, seq_len: int, num_labels: int):
    """Smoke-test that every loss in ``coati.models.loss`` runs on random inputs.

    Only checks that each forward call completes without raising. The returned
    values are not inspected.
    """
    loss = GPTLMLoss()
    loss_input = {
        "logits": torch.randn(batch_size, seq_len, num_labels),
        "labels": torch.randint(0, num_labels, (batch_size, seq_len)),
    }
    loss(**loss_input)

    loss = PolicyLoss()
    loss_input = {
        "log_probs": torch.randn(batch_size),
        "old_log_probs": torch.randn(batch_size),
        "advantages": torch.randn(batch_size),
    }
    loss(**loss_input)

    loss = ValueLoss()
    loss_input = {
        "values": torch.randn(batch_size),
        "old_values": torch.randn(batch_size),
        "reward": torch.randn(batch_size),
    }
    loss(**loss_input)

    # The two pairwise reward-model losses share the same input signature.
    loss = LogSigLoss()
    loss_input = {
        "chosen_reward": torch.randn(batch_size),
        "reject_reward": torch.randn(batch_size),
    }
    loss(**loss_input)

    loss = LogExpLoss()
    loss_input = {
        "chosen_reward": torch.randn(batch_size),
        "reject_reward": torch.randn(batch_size),
    }
    loss(**loss_input)


if __name__ == "__main__":
    # Manual entry point: run each test once with fixed arguments. This includes
    # a Llama generation run, which the parametrized test_generation leaves out
    # because of its long execution time.
    generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
    test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)

    test_utils()

    test_lora(lora_rank=2, num_dim=8, num_layers=2)

    test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)

    test_loss(batch_size=8, seq_len=128, num_labels=100)