import asyncio
import contextlib
import json
import math
import os
import random
import shutil
import subprocess
import sys
import tempfile
import time
from typing import Dict, List, Optional

import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient
from text_generation.types import (
    BestOfSequence,
    ChatComplete,
    ChatCompletionChunk,
    ChatCompletionComplete,
    Completion,
    Details,
    Grammar,
    InputToken,
    Response,
    Token,
)

DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")


def pytest_addoption(parser):
    parser.addoption(
        "--release", action="store_true", default=False, help="run release tests"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "release: mark test as a release-only test")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--release"):
        # --release given in cli: do not skip release tests
        return
    skip_release = pytest.mark.skip(reason="need --release option to run")
    for item in items:
        if "release" in item.keywords:
            item.add_marker(skip_release)


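# syrupy snapshot extension that serializes client responses to JSON and compares
# them against stored snapshots, allowing a relative tolerance on token logprobs.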
class ResponseComparator(JSONSnapshotExtension):
    rtol = 0.2
    ignore_logprob = False

    def serialize(
        self,
        data,
        *,
        include=None,
        exclude=None,
        matcher=None,
    ):
        if (
            isinstance(data, Response)
            or isinstance(data, ChatComplete)
            or isinstance(data, ChatCompletionChunk)
            or isinstance(data, ChatCompletionComplete)
        ):
            data = data.model_dump()

        if isinstance(data, List):
            data = [d.model_dump() for d in data]

        data = self._filter(
            data=data,
            depth=0,
            path=(),
            exclude=exclude,
            include=include,
            matcher=matcher,
        )
        return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"

    def matches(
        self,
        *,
        serialized_data,
        snapshot_data,
    ) -> bool:
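        # Snapshots are stored as JSON; parse both sides back into client types so
        # they can be compared with the tolerant, field-aware checks defined below.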
        def convert_data(data):
            data = json.loads(data)
            if isinstance(data, Dict) and "choices" in data:
                choices = data["choices"]
                if isinstance(choices, List) and len(choices) >= 1:
                    if "delta" in choices[0]:
                        return ChatCompletionChunk(**data)
                    if "text" in choices[0]:
                        return Completion(**data)
                return ChatComplete(**data)

            if isinstance(data, Dict):
                return Response(**data)
            if isinstance(data, List):
                if (
                    len(data) > 0
                    and "object" in data[0]
                    and data[0]["object"] == "text_completion"
                ):
                    return [Completion(**d) for d in data]
                return [Response(**d) for d in data]
            raise NotImplementedError

        def eq_token(token: Token, other: Token) -> bool:
            return (
                token.id == other.id
                and token.text == other.text
                and (
                    self.ignore_logprob
                    or (token.logprob == other.logprob and token.logprob is None)
                    or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
                )
                and token.special == other.special
            )

        def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
            try:
                return (
                    prefill_token.id == other.id
                    and prefill_token.text == other.text
                    and (
                        self.ignore_logprob
                        or math.isclose(
                            prefill_token.logprob,
                            other.logprob,
                            rel_tol=self.rtol,
                        )
                        if prefill_token.logprob is not None
                        else prefill_token.logprob == other.logprob
                    )
                )
            except TypeError:
                return False

        def eq_best_of(details: BestOfSequence, other: BestOfSequence) -> bool:
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
            )

        def eq_details(details: Details, other: Details) -> bool:
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
                and (
                    len(details.best_of_sequences)
                    if details.best_of_sequences is not None
                    else 0
                )
                == (
                    len(other.best_of_sequences)
                    if other.best_of_sequences is not None
                    else 0
                )
                and (
                    all(
                        [
                            eq_best_of(d, o)
                            for d, o in zip(
                                details.best_of_sequences, other.best_of_sequences
                            )
                        ]
                    )
                    if details.best_of_sequences is not None
                    else details.best_of_sequences == other.best_of_sequences
                )
            )

        def eq_completion(response: Completion, other: Completion) -> bool:
            return response.choices[0].text == other.choices[0].text

        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
            return (
                response.choices[0].message.content == other.choices[0].message.content
            )

        def eq_chat_complete_chunk(
            response: ChatCompletionChunk, other: ChatCompletionChunk
        ) -> bool:
            return response.choices[0].delta.content == other.choices[0].delta.content

        def eq_response(response: Response, other: Response) -> bool:
            return response.generated_text == other.generated_text and eq_details(
                response.details, other.details
            )

        serialized_data = convert_data(serialized_data)
        snapshot_data = convert_data(snapshot_data)

        if not isinstance(serialized_data, List):
            serialized_data = [serialized_data]
        if not isinstance(snapshot_data, List):
            snapshot_data = [snapshot_data]

        if isinstance(serialized_data[0], Completion):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatComplete):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatCompletionChunk):
            return len(snapshot_data) == len(serialized_data) and all(
                [
                    eq_chat_complete_chunk(r, o)
                    for r, o in zip(serialized_data, snapshot_data)
                ]
            )

        return len(snapshot_data) == len(serialized_data) and all(
            [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
        )


class GenerousResponseComparator(ResponseComparator):
    # Needed for GPTQ with exllama, which has serious numerical fluctuations.
    rtol = 0.75


class IgnoreLogProbResponseComparator(ResponseComparator):
    ignore_logprob = True


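# Common handle for a running launcher (local subprocess or Docker container):
# exposes an AsyncClient and a health check used by the model fixtures.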
class LauncherHandle:
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}", timeout=30)

    def _inner_health(self):
        raise NotImplementedError

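    # Poll until a trivial generation succeeds, failing fast if the underlying
    # process or container has died.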
    async def health(self, timeout: int = 60):
        assert timeout > 0
        for _ in range(timeout):
            if not self._inner_health():
                raise RuntimeError("Launcher crashed")

            try:
                await self.client.generate("test")
                return
            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
                time.sleep(1)
        raise RuntimeError("Health check failed")


class ContainerLauncherHandle(LauncherHandle):
    def __init__(self, docker_client, container_name, port: int):
        super().__init__(port)
        self.docker_client = docker_client
        self.container_name = container_name

    def _inner_health(self) -> bool:
        container = self.docker_client.containers.get(self.container_name)
        return container.status in ["running", "created"]


class ProcessLauncherHandle(LauncherHandle):
    def __init__(self, process, port: int):
        super().__init__(port)
        self.process = process

    def _inner_health(self) -> bool:
        return self.process.poll() is None


@pytest.fixture
def response_snapshot(snapshot):
    return snapshot.use_extension(ResponseComparator)


@pytest.fixture
def generous_response_snapshot(snapshot):
    return snapshot.use_extension(GenerousResponseComparator)


@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
    return snapshot.use_extension(IgnoreLogProbResponseComparator)


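# Module-scoped event loop so async fixtures and tests in a module share one loop.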
@pytest.fixture(scope="module")
def event_loop():
    loop = asyncio.get_event_loop()
    yield loop
    loop.close()


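# Chooses the launch strategy: if DOCKER_IMAGE is set, the model runs in a container
# built from that image; otherwise `text-generation-launcher` is started as a local
# subprocess. Illustrative usage from a test module (hypothetical model id):
#
#     @pytest.fixture(scope="module")
#     def model_handle(launcher):
#         with launcher("some-org/some-model", num_shard=1) as handle:
#             yield handle
#
#     @pytest.mark.asyncio
#     async def test_model(model_handle):
#         await model_handle.health(300)
#         response = await model_handle.client.generate("Test", max_new_tokens=10)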
@pytest.fixture(scope="module")
def launcher(event_loop):
    @contextlib.contextmanager
    def local_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
        lora_adapters: Optional[List[str]] = None,
        cuda_graphs: Optional[List[int]] = None,
    ):
        port = random.randint(8000, 10_000)
        master_port = random.randint(10_000, 20_000)

        shard_uds_path = (
            f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
        )

        args = [
            "text-generation-launcher",
            "--model-id",
            model_id,
            "--port",
            str(port),
            "--master-port",
            str(master_port),
            "--shard-uds-path",
            shard_uds_path,
        ]

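        # Reuse (and mutate) the test process environment; the USE_FLASH_ATTENTION
        # override added below is removed again after the launcher exits.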
        env = os.environ

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))
        if lora_adapters:
            args.append("--lora-adapters")
            args.append(",".join(lora_adapters))
        if cuda_graphs:
            args.append("--cuda-graphs")
            args.append(",".join(map(str, cuda_graphs)))

        print(" ".join(args), file=sys.stderr)

        env["LOG_LEVEL"] = "info,text_generation_router=debug"

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"

        with tempfile.TemporaryFile("w+") as tmp:
            # We'll output stdout/stderr to a temporary file. Using a pipe
            # would cause the process to block until stdout is read.
            with subprocess.Popen(
                args,
                stdout=tmp,
                stderr=subprocess.STDOUT,
                env=env,
            ) as process:
                yield ProcessLauncherHandle(process, port)

                process.terminate()
                process.wait(60)

                tmp.seek(0)
                shutil.copyfileobj(tmp, sys.stderr)

        if not use_flash_attention:
            del env["USE_FLASH_ATTENTION"]

    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
        lora_adapters: Optional[List[str]] = None,
        cuda_graphs: Optional[List[int]] = None,
    ):
        port = random.randint(8000, 10_000)

        args = ["--model-id", model_id, "--env"]

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))
        if lora_adapters:
            args.append("--lora-adapters")
            args.append(",".join(lora_adapters))
        if cuda_graphs:
            args.append("--cuda-graphs")
            args.append(",".join(map(str, cuda_graphs)))

        client = docker.from_env()

        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"

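        # Stop any container left over from a previous run with the same name.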
        try:
            container = client.containers.get(container_name)
            container.stop()
            container.wait()
        except NotFound:
            pass

        gpu_count = num_shard if num_shard is not None else 1

        env = {
            "LOG_LEVEL": "info,text_generation_router=debug",
        }
        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"

        if HF_TOKEN is not None:
            env["HF_TOKEN"] = HF_TOKEN

        volumes = []
        if DOCKER_VOLUME:
            volumes = [f"{DOCKER_VOLUME}:/data"]

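        # If DOCKER_DEVICES is set (e.g. ROCm device nodes), mount those devices
        # directly; otherwise request `gpu_count` GPUs through a Docker DeviceRequest.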
        if DOCKER_DEVICES:
            devices = DOCKER_DEVICES.split(",")
            visible = os.getenv("ROCR_VISIBLE_DEVICES")
            if visible:
                env["ROCR_VISIBLE_DEVICES"] = visible
            device_requests = []
        else:
            devices = []
            device_requests = [
                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
            ]

        container = client.containers.run(
            DOCKER_IMAGE,
            command=args,
            name=container_name,
            environment=env,
            auto_remove=False,
            detach=True,
            device_requests=device_requests,
            devices=devices,
            volumes=volumes,
            ports={"80/tcp": port},
            shm_size="1G",
        )

        yield ContainerLauncherHandle(client, container.name, port)

        if not use_flash_attention:
            del env["USE_FLASH_ATTENTION"]

        try:
            container.stop()
            container.wait()
        except NotFound:
            pass

        container_output = container.logs().decode("utf-8")
        print(container_output, file=sys.stderr)

        container.remove()

    if DOCKER_IMAGE is not None:
        return docker_launcher
    return local_launcher


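# Issues `n` identical generate requests concurrently and returns all responses.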
@pytest.fixture(scope="module")
def generate_load():
    async def generate_load_inner(
        client: AsyncClient,
        prompt: str,
        max_new_tokens: int,
        n: int,
        seed: Optional[int] = None,
        grammar: Optional[Grammar] = None,
        stop_sequences: Optional[List[str]] = None,
    ) -> List[Response]:
        futures = [
            client.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                decoder_input_details=True,
                seed=seed,
                grammar=grammar,
                stop_sequences=stop_sequences,
            )
            for _ in range(n)
        ]

        return await asyncio.gather(*futures)

    return generate_load_inner