"vscode:/vscode.git/clone" did not exist on "2236a93efccb2aa8d907225c182d54ffd2e90e11"
conftest.py 21.3 KB
Newer Older
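"""Pytest fixtures for the text-generation-inference integration tests.

Provides snapshot comparators for generation responses and a `launcher`
fixture that starts `text-generation-launcher` either as a local process
or inside a Docker container (when the DOCKER_IMAGE env var is set).
"""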
import asyncio
import base64
import contextlib
import json
import math
import os
import random
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional

import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from syrupy.extensions.json import JSONSnapshotExtension

from text_generation import AsyncClient
from text_generation.types import (
    BestOfSequence,
    Message,
    ChatComplete,
    ChatCompletionChunk,
    ChatCompletionComplete,
    Completion,
    Details,
    Grammar,
    InputToken,
    Response,
    Token,
)

DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")


def pytest_addoption(parser):
    parser.addoption(
        "--release", action="store_true", default=False, help="run release tests"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "release: mark test as a release-only test")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--release"):
        # --release given in cli: do not skip release tests
        return
    skip_release = pytest.mark.skip(reason="need --release option to run")
    for item in items:
        if "release" in item.keywords:
            item.add_marker(skip_release)


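# Snapshot extension that compares generation responses semantically:
# token ids/texts must match exactly, while logprobs only need to be within
# `rtol` (or are ignored entirely when `ignore_logprob` is set).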
class ResponseComparator(JSONSnapshotExtension):
    rtol = 0.2
    ignore_logprob = False

    def serialize(
        self,
        data,
        *,
        include=None,
        exclude=None,
        matcher=None,
    ):
        if (
            isinstance(data, Response)
            or isinstance(data, ChatComplete)
            or isinstance(data, ChatCompletionChunk)
            or isinstance(data, ChatCompletionComplete)
        ):
            data = data.model_dump()

        if isinstance(data, List):
            data = [d.model_dump() for d in data]

        data = self._filter(
            data=data,
            depth=0,
            path=(),
            exclude=exclude,
            include=include,
            matcher=matcher,
        )
        return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"

    def matches(
        self,
        *,
        serialized_data,
        snapshot_data,
    ) -> bool:
        def convert_data(data):
            data = json.loads(data)
            return _convert_data(data)

        def _convert_data(data):
            if isinstance(data, Dict):
                if "choices" in data:
                    data["choices"] = list(
                        sorted(data["choices"], key=lambda x: x["index"])
                    )
                    choices = data["choices"]
                    if isinstance(choices, List) and len(choices) >= 1:
                        if "delta" in choices[0]:
                            return ChatCompletionChunk(**data)
                        if "text" in choices[0]:
                            return Completion(**data)
                    return ChatComplete(**data)
                else:
                    return Response(**data)
            if isinstance(data, List):
                return [_convert_data(d) for d in data]
            raise NotImplementedError

        def eq_token(token: Token, other: Token) -> bool:
            return (
                token.id == other.id
                and token.text == other.text
                and (
                    self.ignore_logprob
                    or (token.logprob == other.logprob and token.logprob is None)
                    or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
                )
                and token.special == other.special
            )

        def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
            try:
                return (
                    prefill_token.id == other.id
                    and prefill_token.text == other.text
                    and (
                        self.ignore_logprob
                        or math.isclose(
                            prefill_token.logprob,
                            other.logprob,
                            rel_tol=self.rtol,
                        )
                        if prefill_token.logprob is not None
                        else prefill_token.logprob == other.logprob
                    )
                )
            except TypeError:
                return False

        def eq_best_of(details: BestOfSequence, other: BestOfSequence) -> bool:
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
            )

        def eq_details(details: Details, other: Details) -> bool:
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
                and (
                    len(details.best_of_sequences)
                    if details.best_of_sequences is not None
                    else 0
                )
                == (
                    len(other.best_of_sequences)
                    if other.best_of_sequences is not None
                    else 0
                )
                and (
                    all(
                        [
                            eq_best_of(d, o)
                            for d, o in zip(
                                details.best_of_sequences, other.best_of_sequences
                            )
                        ]
                    )
                    if details.best_of_sequences is not None
                    else details.best_of_sequences == other.best_of_sequences
                )
            )

        def eq_completion(response: Completion, other: Completion) -> bool:
            return response.choices[0].text == other.choices[0].text

        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
            return (
                response.choices[0].message.content == other.choices[0].message.content
            )

        def eq_chat_complete_chunk(
            response: ChatCompletionChunk, other: ChatCompletionChunk
        ) -> bool:
            return response.choices[0].delta.content == other.choices[0].delta.content

        def eq_response(response: Response, other: Response) -> bool:
            return response.generated_text == other.generated_text and eq_details(
                response.details, other.details
            )

        serialized_data = convert_data(serialized_data)
        snapshot_data = convert_data(snapshot_data)

        if not isinstance(serialized_data, List):
            serialized_data = [serialized_data]
        if not isinstance(snapshot_data, List):
            snapshot_data = [snapshot_data]

        if isinstance(serialized_data[0], Completion):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatComplete):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatCompletionChunk):
            return len(snapshot_data) == len(serialized_data) and all(
                [
                    eq_chat_complete_chunk(r, o)
                    for r, o in zip(serialized_data, snapshot_data)
                ]
            )

        return len(snapshot_data) == len(serialized_data) and all(
            [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
        )


class GenerousResponseComparator(ResponseComparator):
    # Needed for GPTQ with exllama which has serious numerical fluctuations.
    rtol = 0.75


class IgnoreLogProbResponseComparator(ResponseComparator):
    ignore_logprob = True


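# Handle around a launched server: `health()` polls the generate endpoint
# until it answers, raising if the underlying process/container died.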
class LauncherHandle:
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}", timeout=30)

    def _inner_health(self):
        raise NotImplementedError

    async def health(self, timeout: int = 60):
        assert timeout > 0
        for _ in range(timeout):
            if not self._inner_health():
                raise RuntimeError("Launcher crashed")

            try:
                await self.client.generate("test")
                return
            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
                time.sleep(1)
        raise RuntimeError("Health check failed")


class ContainerLauncherHandle(LauncherHandle):
    def __init__(self, docker_client, container_name, port: int):
        super(ContainerLauncherHandle, self).__init__(port)
        self.docker_client = docker_client
        self.container_name = container_name

    def _inner_health(self) -> bool:
        container = self.docker_client.containers.get(self.container_name)
        return container.status in ["running", "created"]


class ProcessLauncherHandle(LauncherHandle):
    def __init__(self, process, port: int):
        super(ProcessLauncherHandle, self).__init__(port)
        self.process = process

    def _inner_health(self) -> bool:
        return self.process.poll() is None


@pytest.fixture
def response_snapshot(snapshot):
    return snapshot.use_extension(ResponseComparator)


@pytest.fixture
def generous_response_snapshot(snapshot):
    return snapshot.use_extension(GenerousResponseComparator)


@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
    return snapshot.use_extension(IgnoreLogProbResponseComparator)


@pytest.fixture(scope="module")
def event_loop():
    loop = asyncio.get_event_loop()
    yield loop
    loop.close()


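# Context-manager fixture that starts `text-generation-launcher` for a test
# module, either as a local subprocess or as a Docker container depending on
# whether DOCKER_IMAGE is set.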
@pytest.fixture(scope="module")
def launcher(event_loop):
    @contextlib.contextmanager
    def local_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
        lora_adapters: Optional[List[str]] = None,
        cuda_graphs: Optional[List[int]] = None,
        attention: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)
        master_port = random.randint(10_000, 20_000)

        shard_uds_path = (
            f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
        )

        args = [
            "text-generation-launcher",
            "--model-id",
            model_id,
            "--port",
            str(port),
            "--master-port",
            str(master_port),
            "--shard-uds-path",
            shard_uds_path,
        ]

        env = os.environ

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if kv_cache_dtype is not None:
            args.append("--kv-cache-dtype")
            args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))
        if lora_adapters:
            args.append("--lora-adapters")
            args.append(",".join(lora_adapters))
        if cuda_graphs:
            args.append("--cuda-graphs")
            args.append(",".join(map(str, cuda_graphs)))

        print(" ".join(args), file=sys.stderr)

        env["LOG_LEVEL"] = "info,text_generation_router=debug"
        env["PREFILL_CHUNKING"] = "1"

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
        if attention is not None:
            env["ATTENTION"] = attention

        with tempfile.TemporaryFile("w+") as tmp:
            # We'll output stdout/stderr to a temporary file. Using a pipe
            # causes the process to block until stdout is read.
            with subprocess.Popen(
                args,
                stdout=tmp,
                stderr=subprocess.STDOUT,
                env=env,
            ) as process:
                yield ProcessLauncherHandle(process, port)

                process.terminate()
                process.wait(60)

                tmp.seek(0)
                shutil.copyfileobj(tmp, sys.stderr)

        if not use_flash_attention:
            del env["USE_FLASH_ATTENTION"]

    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
        lora_adapters: Optional[List[str]] = None,
        cuda_graphs: Optional[List[int]] = None,
        attention: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)

        args = ["--model-id", model_id, "--env"]

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if kv_cache_dtype is not None:
            args.append("--kv-cache-dtype")
            args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))
        if lora_adapters:
            args.append("--lora-adapters")
            args.append(",".join(lora_adapters))
        if cuda_graphs:
            args.append("--cuda-graphs")
            args.append(",".join(map(str, cuda_graphs)))

        client = docker.from_env()

        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"

        try:
            container = client.containers.get(container_name)
            container.stop()
            container.remove()
            container.wait()
        except NotFound:
            pass

        gpu_count = num_shard if num_shard is not None else 1

        env = {
            "LOG_LEVEL": "info,text_generation_router=debug",
            "PREFILL_CHUNKING": "1",
        }

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
        if attention is not None:
            env["ATTENTION"] = attention

        if HF_TOKEN is not None:
            env["HF_TOKEN"] = HF_TOKEN

        volumes = []
        if DOCKER_VOLUME:
            volumes = [f"{DOCKER_VOLUME}:/data"]

        if DOCKER_DEVICES:
            if DOCKER_DEVICES.lower() == "none":
                devices = []
            else:
                devices = DOCKER_DEVICES.strip().split(",")
            visible = os.getenv("ROCR_VISIBLE_DEVICES")
            if visible:
                env["ROCR_VISIBLE_DEVICES"] = visible
            device_requests = []
            if not devices:
                devices = None
            elif devices == ["nvidia.com/gpu=all"]:
                devices = None
                device_requests = [
                    docker.types.DeviceRequest(
                        driver="cdi",
                        # count=gpu_count,
                        device_ids=[f"nvidia.com/gpu={i}"],
                    )
                    for i in range(gpu_count)
                ]
        else:
            devices = None
            device_requests = [
                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
            ]

        container = client.containers.run(
            DOCKER_IMAGE,
            command=args,
            name=container_name,
            environment=env,
            auto_remove=False,
            detach=True,
            device_requests=device_requests,
            devices=devices,
            volumes=volumes,
            ports={"80/tcp": port},
            healthcheck={"timeout": int(10 * 1e9)},
            shm_size="1G",
        )

        try:
            yield ContainerLauncherHandle(client, container.name, port)

            if not use_flash_attention:
                del env["USE_FLASH_ATTENTION"]

            try:
                container.stop()
                container.wait()
            except NotFound:
                pass

            container_output = container.logs().decode("utf-8")
            print(container_output, file=sys.stderr)

        finally:
            try:
                container.remove()
            except Exception:
                pass

    if DOCKER_IMAGE is not None:
        return docker_launcher
    return local_launcher


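# Fires `n` identical generate requests concurrently against the launched
# server and returns all responses.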
@pytest.fixture(scope="module")
def generate_load():
    async def generate_load_inner(
        client: AsyncClient,
        prompt: str,
        max_new_tokens: int,
        n: int,
        seed: Optional[int] = None,
        grammar: Optional[Grammar] = None,
        stop_sequences: Optional[List[str]] = None,
    ) -> List[Response]:
        futures = [
            client.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                decoder_input_details=True,
                seed=seed,
                grammar=grammar,
                stop_sequences=stop_sequences,
            )
            for _ in range(n)
        ]

        return await asyncio.gather(*futures)

    return generate_load_inner


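# Sends one chat request per prompt (in shuffled order) and returns the
# responses re-ordered to match the original prompt order.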
@pytest.fixture(scope="module")
def generate_multi():
    async def generate_load_inner(
        client: AsyncClient,
        prompts: List[str],
        max_new_tokens: int,
        seed: Optional[int] = None,
    ) -> List[Response]:
        import numpy as np

        arange = np.arange(len(prompts))
        perm = np.random.permutation(arange)
        rperm = [-1] * len(perm)
        for i, p in enumerate(perm):
            rperm[p] = i

        shuffled_prompts = [prompts[p] for p in perm]
        futures = [
            client.chat(
                messages=[Message(role="user", content=prompt)],
                max_tokens=max_new_tokens,
                temperature=0,
                seed=seed,
            )
            for prompt in shuffled_prompts
        ]

        shuffled_responses = await asyncio.gather(*futures)
        responses = [shuffled_responses[p] for p in rperm]
        return responses

    return generate_load_inner


# TODO fix the server parser to count inline image tokens correctly
@pytest.fixture
def chicken():
    path = Path(__file__).parent / "images" / "chicken_on_money.png"

    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.fixture
def cow_beach():
    path = Path(__file__).parent / "images" / "cow_beach.png"

    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"