# test_causal_lm.py: tests for CausalLM / CausalLMBatch using the GPT-2 model.
import pytest
import torch

from copy import copy
5
from transformers import AutoTokenizer
6

7
8
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
9
10


@pytest.fixture(scope="session")
def default_causal_lm():
    """Build a GPT-2 ``CausalLM`` once per test session and share it."""
    model_name = "gpt2"
    model = CausalLM(model_name)
    return model


@pytest.fixture(scope="session")
def gpt2_tokenizer():
    """Session-wide GPT-2 tokenizer that pads on the left with token id 50256."""
    tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    # NOTE(review): an explicit pad id is set here, presumably because the
    # stock GPT-2 tokenizer does not define one — confirm.
    tok.pad_token_id = 50256
    return tok


@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
    """Return a single protobuf ``Request`` (id 0) for the prompt ``"Test"``.

    ``default_pb_parameters`` and ``default_pb_stop_parameters`` are fixtures
    defined outside this file (presumably in ``conftest.py`` — confirm).
    """
    return generate_pb2.Request(
        id=0,
        inputs="Test",
        input_length=1,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )


@pytest.fixture
def default_pb_batch(default_pb_request):
    """Wrap ``default_pb_request`` in a one-element protobuf ``Batch`` with id 0."""
    requests = [default_pb_request]
    return generate_pb2.Batch(id=0, requests=requests, size=len(requests))


@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
    """Convert ``default_pb_batch`` into a ``CausalLMBatch`` placed on the CPU."""
    cpu = torch.device("cpu")
    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, cpu)


@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
    """Build a two-request ``CausalLMBatch`` derived from ``default_pb_request``.

    Request 0 is an unchanged copy of the default request. Request 1 gets
    ``id=1`` and ``max_new_tokens=5``, so the two requests stop at different
    decoding steps.

    Both requests are copies. The injected ``default_pb_request`` object is
    never mutated, because other fixtures used by the same test (e.g.
    ``default_causal_lm_batch``) receive that very object. Mutating it would
    make results depend on fixture instantiation order.
    """
    req_0 = copy(default_pb_request)
    req_1 = copy(default_pb_request)
    req_1.id = 1
    req_1.stopping_parameters.max_new_tokens = 5

    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))


def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
    """``CausalLMBatch.from_pb`` mirrors the protobuf batch and tokenizes the prompt.

    The prompt ``"Test"`` ends in token id 14402. Any left padding uses the
    pad id 50256 configured on ``gpt2_tokenizer``.
    """
    batch = default_causal_lm_batch

    # Batch metadata is carried over from the protobuf message.
    assert batch.batch_id == default_pb_batch.id
    assert batch.requests == default_pb_batch.requests

    # The prompt token is last in the row; anything before it is padding.
    assert len(batch.input_ids) == default_pb_batch.size
    assert batch.input_ids[0][-1] == 14402
    assert torch.all(batch.input_ids[0][:-1] == 50256)

    # Only the first attention-mask column is active; the rest are zero.
    assert batch.attention_mask[0, 0] == 1
    assert torch.all(batch.attention_mask[0, 1:] == 0)

    # generate_token has not been called yet, so there is no KV cache.
    assert batch.past_key_values is None

    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])

    assert batch.input_lengths == [1]

    assert batch.size == default_pb_batch.size
    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

    assert batch.max_sequence_length == batch.input_lengths[0]


def test_batch_concatenate_no_prefill(default_causal_lm_batch):
    """Concatenating two un-generated copies of a batch raises ``ValueError``."""
    batches = [default_causal_lm_batch] * 2
    with pytest.raises(ValueError):
        CausalLMBatch.concatenate(batches)


def test_causal_lm_batch_type(default_causal_lm):
    """``CausalLM`` reports ``CausalLMBatch`` as its batch type."""
    reported = default_causal_lm.batch_type
    assert reported == CausalLMBatch


def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
    """Check a single ``generate_token`` step on the one-request batch.

    The model is expected to emit token 13 (``"."``) after the prompt token
    14402. The returned batch must grow by one token and carry a KV cache, and
    the matching ``Generation`` must describe the emitted token.
    """
    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)

    assert len(generations) == len(next_batch)
    assert isinstance(next_batch, CausalLMBatch)

    # The new token is appended after the prompt token; earlier slots are padding.
    assert len(next_batch.all_input_ids) == next_batch.size
    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
    assert len(next_batch.attention_mask[0]) == 11
    assert next_batch.all_input_ids[0][-1] == 13
    assert next_batch.all_input_ids[0][-2] == 14402
    assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)

    # Prompt token and new token are attended to; the remaining mask columns
    # stay zero. They are presumably right-side space reserved for later
    # tokens — confirm against CausalLMBatch.padding_right_offset.
    assert torch.all(next_batch.attention_mask[0][0:2] == 1)
    assert torch.all(next_batch.attention_mask[0][2:] == 0)

    # The next model input is only the freshly generated token.
    assert next_batch.input_ids.shape == (next_batch.size, 1)
    assert next_batch.input_ids[0, 0] == 13

    assert next_batch.input_lengths == [2]
    assert next_batch.max_sequence_length == next_batch.input_lengths[0]

    # past_key_values is now populated. Each entry's first two tensors
    # (presumably the key and value of one layer — confirm) have shape
    # (1, 12, sequence_length, 64).
    assert next_batch.past_key_values is not None
    assert all(
        p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values
    )
    assert all(
        p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values
    )

    # generated_text is still None for every request. Each Generation reports
    # one prefill token and the generated "." token.
    assert all(generation.generated_text is None for generation in generations)
    assert all(len(generation.prefill_tokens) == 1 for generation in generations)
    assert all(generation.token_id.item() == 13 for generation in generations)
    assert all(generation.token_text == "." for generation in generations)
    assert generations[0].request_id == 0


def test_causal_lm_generate_token_completion(
    default_causal_lm, default_causal_lm_batch
):
    """Decode the one-request batch until its stopping criteria fire.

    After ``max_new_tokens - 1`` steps the request is still running. The next
    step finishes it, returns ``None`` as the next batch, and yields the full
    generated text.
    """
    max_new_tokens = default_causal_lm_batch.stopping_criterias[0].max_new_tokens

    next_batch = default_causal_lm_batch
    for _ in range(max_new_tokens - 1):
        generations, next_batch = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    # Final step: the only request finishes, so no batch is returned.
    generations, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
    assert generations[0].generated_text.generated_tokens == max_new_tokens


def test_causal_lm_generate_token_completion_multi(
    default_causal_lm, default_multi_requests_causal_lm_batch
):
    """Decode a two-request batch whose requests stop at different lengths.

    Request 1 has the smaller ``max_new_tokens``. It finishes first while the
    batch stays alive for request 0. The batch is exhausted only once request
    0 also reaches its own limit.
    """
    initial_batch = default_multi_requests_causal_lm_batch
    stop_0 = initial_batch.stopping_criterias[0].max_new_tokens
    stop_1 = initial_batch.stopping_criterias[1].max_new_tokens

    next_batch = initial_batch
    for _ in range(stop_1 - 1):
        generations, next_batch = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    # Request 1 finishes on this step; request 0 keeps the batch alive.
    generations, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
    assert generations[1].generated_text.text == ".java:784)"
    assert generations[1].request_id == initial_batch.requests[1].id
    assert generations[1].generated_text.generated_tokens == stop_1

    # Keep decoding the remaining request until its own limit.
    for _ in range(stop_0 - stop_1 - 1):
        generations, next_batch = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    generations, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
    assert generations[0].request_id == initial_batch.requests[0].id
    assert generations[0].generated_text.generated_tokens == stop_0


def test_batch_concatenate(
    default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
    """Concatenate batches at different decoding offsets and decode to completion.

    Before merging, the one-request batch is stepped twice and the two-request
    batch once, so the merged batch mixes rows of input length 3 and 2. The
    test checks the merged tensors and metadata. It then keeps decoding and
    verifies each request finishes with the expected text at the step its
    stopping criteria dictate.
    """
    single_batch = default_causal_lm_batch
    multi_batch = default_multi_requests_causal_lm_batch
    stop_single = single_batch.stopping_criterias[0].max_new_tokens
    stop_multi_0 = multi_batch.stopping_criterias[0].max_new_tokens
    stop_multi_1 = multi_batch.stopping_criterias[1].max_new_tokens

    # Advance the single-request batch by two tokens and the multi batch by one.
    next_batch_0 = single_batch
    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)

    next_batch_1 = multi_batch
    _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)

    next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])

    # Rows keep their original token histories, in concatenation order.
    assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])

    # Row 0 (3 tokens) is active up to the right-padding region. Rows 1-2
    # (2 tokens each) are active from column 1. Nothing from column 3 on is
    # set for them yet.
    assert torch.all(
        next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
    )
    assert torch.all(
        next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
    )
    assert torch.all(next_batch.attention_mask[1:, 3:] == 0)

    # Each row's next input is its most recently generated token.
    assert next_batch.batch_id == 0
    assert next_batch.input_ids[0, 0] == 12355
    assert torch.all(next_batch.input_ids[1:] == 13)

    assert next_batch.input_lengths == [3, 2, 2]
    assert next_batch.max_sequence_length == 3

    assert next_batch.requests[0] == next_batch_0.requests[0]
    assert next_batch.requests[1:] == next_batch_1.requests

    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias

    assert next_batch.past_key_values is not None
    assert all(p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values)
    assert all(p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values)

    # Merged cache entries must equal the trailing positions of each source
    # batch's cache.
    for i, past in enumerate(next_batch.past_key_values):
        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
        )

        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
        )

    # Stage 1: multi-batch request 1 already produced 1 token before the merge.
    # (stop_multi_1 - 2) steps plus one final step bring it to stop_multi_1.
    for _ in range(stop_multi_1 - 2):
        generations, next_batch = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    generations, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 3
    assert generations[2].generated_text.text == ".java:784)"
    assert generations[2].request_id == multi_batch.requests[1].id
    assert generations[2].generated_text.generated_tokens == stop_multi_1

    # Stage 2: the single-batch request has produced stop_multi_1 + 1 tokens.
    # (stop_single - stop_multi_1 - 2) steps plus one final step finish it.
    for _ in range(stop_single - stop_multi_1 - 2):
        generations, next_batch = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    generations, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
    assert generations[0].request_id == single_batch.requests[0].id
    assert generations[0].generated_text.generated_tokens == stop_single

    # Stage 3: multi-batch request 0 has produced stop_single - 1 tokens so far.
    # It needs (stop_multi_0 - stop_single) more steps before the final one.
    # The previous expression here went negative and only worked because
    # range() of a negative count is empty.
    for _ in range(stop_multi_0 - stop_single):
        generations, next_batch = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    generations, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
    assert generations[0].request_id == multi_batch.requests[0].id
    assert generations[0].generated_text.generated_tokens == stop_multi_0