# ruff: noqa: E402
import requests


# Give every requests.Session a default 120s timeout so that hung HTTP calls
# fail fast instead of blocking the test suite indefinitely.
class SessionTimeoutFix(requests.Session):
    def request(self, *args, **kwargs):
        timeout = kwargs.pop("timeout", 120)
        return super().request(*args, **kwargs, timeout=timeout)


# Monkeypatch the class globally so any library creating a Session gets it.
requests.sessions.Session = SessionTimeoutFix

import asyncio
import base64
import contextlib
import json
import math
import os
import random
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional

import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from syrupy.extensions.json import JSONSnapshotExtension

from text_generation import AsyncClient
from text_generation.types import (
    BestOfSequence,
    ChatComplete,
    ChatCompletionChunk,
    ChatCompletionComplete,
    Completion,
    Details,
    Grammar,
    InputToken,
    Message,
    Response,
    Token,
)

# Test-environment configuration, read once at import time.
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")


# Register a --release CLI flag; tests marked "release" are skipped without it.
def pytest_addoption(parser):
    parser.addoption(
        "--release", action="store_true", default=False, help="run release tests"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "release: mark test as a release-only test")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--release"):
        # --release given in cli: do not skip release tests
        return
    skip_release = pytest.mark.skip(reason="need --release option to run")
    for item in items:
        if "release" in item.keywords:
            item.add_marker(skip_release)
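
# Hypothetical usage in a test module: a test marked with
#     @pytest.mark.release
#     def test_release_only(...): ...
# is collected but skipped unless pytest is invoked with --release.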


class ResponseComparator(JSONSnapshotExtension):
    # Default tolerance for logprob comparisons; generation results drift
    # slightly across hardware, kernels and batch sizes.
    rtol = 0.2
    ignore_logprob = False

    def serialize(
        self,
        data,
        *,
        include=None,
        exclude=None,
        matcher=None,
    ):
        if (
            isinstance(data, Response)
            or isinstance(data, ChatComplete)
            or isinstance(data, ChatCompletionChunk)
            or isinstance(data, ChatCompletionComplete)
        ):
            data = data.model_dump()

        if isinstance(data, List):
            data = [d.model_dump() for d in data]

        data = self._filter(
            data=data,
            depth=0,
            path=(),
            exclude=exclude,
            include=include,
            matcher=matcher,
        )
        return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"

    def matches(
        self,
        *,
        serialized_data,
        snapshot_data,
    ) -> bool:
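        # Strategy: parse both payloads back into client types, then compare
        # field by field, allowing logprobs to drift within rtol.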
        def convert_data(data):
            data = json.loads(data)
            return _convert_data(data)

        def _convert_data(data):
            if isinstance(data, Dict):
                if "choices" in data:
                    data["choices"] = list(
                        sorted(data["choices"], key=lambda x: x["index"])
                    )
                    choices = data["choices"]
                    if isinstance(choices, List) and len(choices) >= 1:
                        if "delta" in choices[0]:
                            return ChatCompletionChunk(**data)
                        if "text" in choices[0]:
                            return Completion(**data)
                    return ChatComplete(**data)
                else:
                    return Response(**data)
            if isinstance(data, List):
                return [_convert_data(d) for d in data]
            raise NotImplementedError

        def eq_token(token: Token, other: Token) -> bool:
            return (
                token.id == other.id
                and token.text == other.text
                and (
                    self.ignore_logprob
                    # Both logprobs may be None; otherwise compare within rtol.
                    or (token.logprob is None and other.logprob is None)
                    or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
                )
                and token.special == other.special
            )

        def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
            try:
                return (
                    prefill_token.id == other.id
                    and prefill_token.text == other.text
                    and (
                        self.ignore_logprob
                        or math.isclose(
                            prefill_token.logprob,
                            other.logprob,
                            rel_tol=self.rtol,
                        )
                        if prefill_token.logprob is not None
                        else prefill_token.logprob == other.logprob
                    )
                )
            except TypeError:
                return False

        def eq_best_of(details: BestOfSequence, other: BestOfSequence) -> bool:
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
            )

        def eq_details(details: Details, other: Details) -> bool:
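            # Same field-by-field comparison as eq_best_of, plus the nested
            # best_of_sequences when present.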
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
                and (
                    len(details.best_of_sequences)
                    if details.best_of_sequences is not None
                    else 0
                )
                == (
                    len(other.best_of_sequences)
                    if other.best_of_sequences is not None
                    else 0
                )
                and (
                    all(
                        [
                            eq_best_of(d, o)
                            for d, o in zip(
                                details.best_of_sequences, other.best_of_sequences
                            )
                        ]
                    )
                    if details.best_of_sequences is not None
                    else details.best_of_sequences == other.best_of_sequences
                )
            )

        def eq_completion(response: Completion, other: Completion) -> bool:
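            # Completions are compared on the generated text only.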
            return response.choices[0].text == other.choices[0].text

        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
            return (
                response.choices[0].message.content == other.choices[0].message.content
            )

        def eq_chat_complete_chunk(
            response: ChatCompletionChunk, other: ChatCompletionChunk
        ) -> bool:
            return response.choices[0].delta.content == other.choices[0].delta.content

        def eq_response(response: Response, other: Response) -> bool:
            return response.generated_text == other.generated_text and eq_details(
                response.details, other.details
            )

        serialized_data = convert_data(serialized_data)
        snapshot_data = convert_data(snapshot_data)

        if not isinstance(serialized_data, List):
            serialized_data = [serialized_data]
        if not isinstance(snapshot_data, List):
            snapshot_data = [snapshot_data]

        # Dispatch on the payload type; every branch also checks lengths match.
        if isinstance(serialized_data[0], Completion):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatComplete):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatCompletionChunk):
            return len(snapshot_data) == len(serialized_data) and all(
                [
                    eq_chat_complete_chunk(r, o)
                    for r, o in zip(serialized_data, snapshot_data)
                ]
            )

        return len(snapshot_data) == len(serialized_data) and all(
            [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
        )


class GenerousResponseComparator(ResponseComparator):
    # Needed for GPTQ with exllama which has serious numerical fluctuations.
    rtol = 0.75


class IgnoreLogProbResponseComparator(ResponseComparator):
    ignore_logprob = True


# Handle around a running launcher; subclasses provide the liveness check for
# a local process vs. a Docker container.
class LauncherHandle:
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}", timeout=30)

    def _inner_health(self):
        raise NotImplementedError

    async def health(self, timeout: int = 60):
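        # Poll once per second until the server answers a generate call,
        # failing fast if the underlying process/container has died.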
        assert timeout > 0
        for _ in range(timeout):
            if not self._inner_health():
                raise RuntimeError("Launcher crashed")

            try:
                await self.client.generate("test")
                return
            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
                time.sleep(1)
        raise RuntimeError("Health check failed")


class ContainerLauncherHandle(LauncherHandle):
    def __init__(self, docker_client, container_name, port: int):
        super(ContainerLauncherHandle, self).__init__(port)
        self.docker_client = docker_client
        self.container_name = container_name

    def _inner_health(self) -> bool:
        container = self.docker_client.containers.get(self.container_name)
        return container.status in ["running", "created"]


class ProcessLauncherHandle(LauncherHandle):
    def __init__(self, process, port: int):
        super(ProcessLauncherHandle, self).__init__(port)
        self.process = process

    def _inner_health(self) -> bool:
        return self.process.poll() is None


@pytest.fixture
def response_snapshot(snapshot):
    return snapshot.use_extension(ResponseComparator)


@pytest.fixture
def generous_response_snapshot(snapshot):
    return snapshot.use_extension(GenerousResponseComparator)


@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
    return snapshot.use_extension(IgnoreLogProbResponseComparator)


# Override pytest-asyncio's default function-scoped event_loop so a single
# loop is shared by all tests in a module (matching the module-scoped
# launcher fixture).
@pytest.fixture(scope="module")
def event_loop():
    loop = asyncio.get_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="module")
def launcher(event_loop):
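    # Returns a context-manager factory that starts text-generation-inference
    # either as a local process or inside a Docker container, depending on
    # whether DOCKER_IMAGE is set.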
    @contextlib.contextmanager
    def local_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
        lora_adapters: Optional[List[str]] = None,
        cuda_graphs: Optional[List[int]] = None,
        attention: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)
        master_port = random.randint(10_000, 20_000)

        shard_uds_path = (
            f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
        )

        args = [
            "text-generation-launcher",
            "--model-id",
            model_id,
            "--port",
            str(port),
            "--master-port",
            str(master_port),
            "--shard-uds-path",
            shard_uds_path,
        ]

        # Note: mutates the real process environment; USE_FLASH_ATTENTION is
        # deleted again after the launcher exits.
        env = os.environ

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if kv_cache_dtype is not None:
            args.append("--kv-cache-dtype")
            args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))
        if lora_adapters:
            args.append("--lora-adapters")
            args.append(",".join(lora_adapters))
        if cuda_graphs:
            args.append("--cuda-graphs")
            args.append(",".join(map(str, cuda_graphs)))

        print(" ".join(args), file=sys.stderr)
419

420
        env["LOG_LEVEL"] = "info,text_generation_router=debug"
        env["PREFILL_CHUNKING"] = "1"

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
        if attention is not None:
            env["ATTENTION"] = attention

        with tempfile.TemporaryFile("w+") as tmp:
            # We'll output stdout/stderr to a temporary file. Using a pipe
            # causes the process to block until stdout is read.
            with subprocess.Popen(
                args,
                stdout=tmp,
                stderr=subprocess.STDOUT,
                env=env,
            ) as process:
                yield ProcessLauncherHandle(process, port)
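                # Once the test finishes, stop the launcher and replay its
                # captured output on stderr.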

                process.terminate()
                process.wait(60)

                tmp.seek(0)
                shutil.copyfileobj(tmp, sys.stderr)

        if not use_flash_attention:
            del env["USE_FLASH_ATTENTION"]

    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
        lora_adapters: Optional[List[str]] = None,
        cuda_graphs: Optional[List[int]] = None,
        attention: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)

        args = ["--model-id", model_id, "--env"]

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if kv_cache_dtype is not None:
            args.append("--kv-cache-dtype")
            args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))
        if lora_adapters:
            args.append("--lora-adapters")
            args.append(",".join(lora_adapters))
        if cuda_graphs:
            args.append("--cuda-graphs")
            args.append(",".join(map(str, cuda_graphs)))

        client = docker.from_env()

        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"

        try:
            # Remove any stale container left over from a previous run.
            container = client.containers.get(container_name)
            container.stop()
            container.wait()
            container.remove()
        except NotFound:
            pass

        gpu_count = num_shard if num_shard is not None else 1

        env = {
            "LOG_LEVEL": "info,text_generation_router=debug",
            "PREFILL_CHUNKING": "1",
        }
        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
        if attention is not None:
            env["ATTENTION"] = attention

        if HF_TOKEN is not None:
            env["HF_TOKEN"] = HF_TOKEN

        volumes = []
        if DOCKER_VOLUME:
            volumes = [f"{DOCKER_VOLUME}:/data"]

        if DOCKER_DEVICES:
            if DOCKER_DEVICES.lower() == "none":
                devices = []
            else:
                devices = DOCKER_DEVICES.strip().split(",")
            visible = os.getenv("ROCR_VISIBLE_DEVICES")
            if visible:
                env["ROCR_VISIBLE_DEVICES"] = visible
            device_requests = []
            if not devices:
                devices = None
            elif devices == ["nvidia.com/gpu=all"]:
                # Expand "all" into one CDI device request per expected GPU.
                devices = None
                device_requests = [
                    docker.types.DeviceRequest(
                        driver="cdi",
                        # count=gpu_count,
                        device_ids=[f"nvidia.com/gpu={i}"],
                    )
                    for i in range(gpu_count)
                ]
        else:
            # Default: request gpu_count GPUs through the container runtime.
            devices = None
            device_requests = [
                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
            ]

        container = client.containers.run(
            DOCKER_IMAGE,
            command=args,
            name=container_name,
            environment=env,
            auto_remove=False,
            detach=True,
            device_requests=device_requests,
            devices=devices,
            volumes=volumes,
            ports={"80/tcp": port},
            healthcheck={"timeout": int(60 * 1e9), "retries": 2},  # 60s
            shm_size="1G",
        )

        try:
            yield ContainerLauncherHandle(client, container.name, port)

            if not use_flash_attention:
                del env["USE_FLASH_ATTENTION"]

            try:
                container.stop()
                container.wait()
            except NotFound:
                pass

            container_output = container.logs().decode("utf-8")
            print(container_output, file=sys.stderr)

        finally:
            # Always remove the container, even if the test body raised.
            try:
                container.remove()
            except Exception:
                pass

    if DOCKER_IMAGE is not None:
        return docker_launcher
    return local_launcher


@pytest.fixture(scope="module")
def generate_load():
    async def generate_load_inner(
        client: AsyncClient,
        prompt: str,
        max_new_tokens: int,
        n: int,
        seed: Optional[int] = None,
        grammar: Optional[Grammar] = None,
        stop_sequences: Optional[List[str]] = None,
    ) -> List[Response]:
        futures = [
            client.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                decoder_input_details=True,
                seed=seed,
                grammar=grammar,
                stop_sequences=stop_sequences,
            )
            for _ in range(n)
        ]

        return await asyncio.gather(*futures)

    return generate_load_inner


@pytest.fixture(scope="module")
def generate_multi():
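    # Sends chat requests for several prompts concurrently in shuffled order,
    # then restores the original prompt order in the returned responses.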
    async def generate_load_inner(
        client: AsyncClient,
        prompts: List[str],
        max_new_tokens: int,
        seed: Optional[int] = None,
    ) -> List[Response]:
        import numpy as np

        # Shuffle the prompts and remember the inverse permutation so the
        # responses can be returned in the original order.
        arange = np.arange(len(prompts))
        perm = np.random.permutation(arange)
        rperm = [-1] * len(perm)
        for i, p in enumerate(perm):
            rperm[p] = i

        shuffled_prompts = [prompts[p] for p in perm]
        futures = [
            client.chat(
                messages=[Message(role="user", content=prompt)],
                max_tokens=max_new_tokens,
                temperature=0,
                seed=seed,
            )
            for prompt in shuffled_prompts
        ]

        shuffled_responses = await asyncio.gather(*futures)
        responses = [shuffled_responses[p] for p in rperm]
        return responses

    return generate_load_inner


# TODO: fix the server parser to count inline image tokens correctly
@pytest.fixture
def chicken():
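    # Base64-encode a test image as a data URI for multimodal requests.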
    path = Path(__file__).parent / "images" / "chicken_on_money.png"

    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.fixture
def cow_beach():
    path = Path(__file__).parent / "images" / "cow_beach.png"

    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"