"vscode:/vscode.git/clone" did not exist on "95cb38ae474848f5c5787916344a96def8c7ce81"
conftest.py 17.4 KB
Newer Older
1
import asyncio
xuxzh1's avatar
last  
xuxzh1 committed
2
import contextlib
3
4
import json
import math
xuxzh1's avatar
last  
xuxzh1 committed
5
import os
6
import random
7
import re
xuxzh1's avatar
last  
xuxzh1 committed
8
9
10
11
12
13
import shutil
import subprocess
import sys
import tempfile
import time
from typing import Dict, List, Optional
14

xuxzh1's avatar
last  
xuxzh1 committed
15
16
17
import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
18
from docker.errors import NotFound
19
from syrupy.extensions.json import JSONSnapshotExtension
20
from text_generation import AsyncClient
from text_generation.types import (
    BestOfSequence,
    ChatComplete,
    ChatCompletionChunk,
    ChatCompletionComplete,
    Completion,
    Details,
    Grammar,
    InputToken,
    Response,
    Token,
)

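# Harness configuration, read from the environment:
# - DOCKER_IMAGE: image under test; when set, models are launched in Docker,
#   otherwise a local `text-generation-launcher` process is spawned.
# - HF_TOKEN: forwarded to the container so gated models can be downloaded.
# - DOCKER_VOLUME: host path mounted at /data inside the container.
# - DOCKER_DEVICES: comma-separated device paths to pass through (e.g. ROCm).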
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")


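# Tests marked @pytest.mark.release only run when pytest is invoked with
# --release; otherwise they are skipped at collection time.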
def pytest_addoption(parser):
    parser.addoption(
        "--release", action="store_true", default=False, help="run release tests"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "release: mark test as a release-only test")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--release"):
        # --release given in cli: do not skip release tests
        return
    skip_release = pytest.mark.skip(reason="need --release option to run")
    for item in items:
        if "release" in item.keywords:
            item.add_marker(skip_release)


class ResponseComparator(JSONSnapshotExtension):
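    """Snapshot comparator for TGI responses: ids, texts, and finish reasons
    must match exactly, while token log-probabilities only need to agree
    within `rtol` (or are ignored entirely when `ignore_logprob` is set)."""
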
    rtol = 0.2
    ignore_logprob = False

    def serialize(
        self,
        data,
        *,
        exclude=None,
        matcher=None,
    ):
        if (
            isinstance(data, Response)
            or isinstance(data, ChatComplete)
            or isinstance(data, ChatCompletionChunk)
            or isinstance(data, ChatCompletionComplete)
        ):
            data = data.model_dump()

        if isinstance(data, List):
            data = [d.model_dump() for d in data]

        data = self._filter(
            data=data, depth=0, path=(), exclude=exclude, matcher=matcher
        )
        return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"

    def matches(
        self,
        *,
        serialized_data,
        snapshot_data,
    ) -> bool:
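        """Parse both serialized payloads back into client types and compare
        them structurally rather than byte-for-byte."""
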
        def convert_data(data):
            data = json.loads(data)

            if isinstance(data, Dict) and "choices" in data:
                choices = data["choices"]
                if isinstance(choices, List) and len(choices) >= 1:
                    if "delta" in choices[0]:
                        return ChatCompletionChunk(**data)
                    if "text" in choices[0]:
                        return Completion(**data)
                return ChatComplete(**data)

            if isinstance(data, Dict):
                return Response(**data)
            if isinstance(data, List):
                if (
                    len(data) > 0
                    and "object" in data[0]
                    and data[0]["object"] == "text_completion"
                ):
                    return [Completion(**d) for d in data]
                return [Response(**d) for d in data]
            raise NotImplementedError

        def eq_token(token: Token, other: Token) -> bool:
            return (
                token.id == other.id
                and token.text == other.text
                and (
                    self.ignore_logprob
                    or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
                )
                and token.special == other.special
            )

        def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
            try:
                return (
                    prefill_token.id == other.id
                    and prefill_token.text == other.text
                    and (
                        self.ignore_logprob
                        or math.isclose(
                            prefill_token.logprob,
                            other.logprob,
                            rel_tol=self.rtol,
                        )
                        if prefill_token.logprob is not None
                        else prefill_token.logprob == other.logprob
                    )
                )
            except TypeError:
                return False

        def eq_best_of(details: BestOfSequence, other: BestOfSequence) -> bool:
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
            )

        def eq_details(details: Details, other: Details) -> bool:
            return (
                details.finish_reason == other.finish_reason
                and details.generated_tokens == other.generated_tokens
                and details.seed == other.seed
                and len(details.prefill) == len(other.prefill)
                and all(
                    [
                        eq_prefill_token(d, o)
                        for d, o in zip(details.prefill, other.prefill)
                    ]
                )
                and len(details.tokens) == len(other.tokens)
                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])
                and (
                    len(details.best_of_sequences)
                    if details.best_of_sequences is not None
                    else 0
                )
                == (
                    len(other.best_of_sequences)
                    if other.best_of_sequences is not None
                    else 0
                )
                and (
                    all(
                        [
                            eq_best_of(d, o)
                            for d, o in zip(
                                details.best_of_sequences, other.best_of_sequences
                            )
                        ]
                    )
                    if details.best_of_sequences is not None
                    else details.best_of_sequences == other.best_of_sequences
                )
            )

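        # OpenAI-style completion/chat payloads are compared on the generated
        # text/content of the first choice only.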
        def eq_completion(response: Completion, other: Completion) -> bool:
            return response.choices[0].text == other.choices[0].text

        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
            return (
                response.choices[0].message.content == other.choices[0].message.content
            )

        def eq_chat_complete_chunk(
            response: ChatCompletionChunk, other: ChatCompletionChunk
        ) -> bool:
            return response.choices[0].delta.content == other.choices[0].delta.content

        def eq_response(response: Response, other: Response) -> bool:
            return response.generated_text == other.generated_text and eq_details(
                response.details, other.details
            )

        serialized_data = convert_data(serialized_data)
        snapshot_data = convert_data(snapshot_data)

        if not isinstance(serialized_data, List):
            serialized_data = [serialized_data]
        if not isinstance(snapshot_data, List):
            snapshot_data = [snapshot_data]

        if isinstance(serialized_data[0], Completion):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatComplete):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatCompletionChunk):
            return len(snapshot_data) == len(serialized_data) and all(
                [
                    eq_chat_complete_chunk(r, o)
                    for r, o in zip(serialized_data, snapshot_data)
                ]
            )

        return len(snapshot_data) == len(serialized_data) and all(
            [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
        )


class GenerousResponseComparator(ResponseComparator):
    # Needed for GPTQ with exllama which has serious numerical fluctuations.
    rtol = 0.75


class IgnoreLogProbResponseComparator(ResponseComparator):
    ignore_logprob = True


class LauncherHandle:
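    """Handle on a freshly launched server: wraps an AsyncClient bound to the
    local port and exposes `health()` to poll for readiness."""
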
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}")

    def _inner_health(self):
        raise NotImplementedError

    async def health(self, timeout: int = 60):
        assert timeout > 0
        for _ in range(timeout):
            if not self._inner_health():
                raise RuntimeError("Launcher crashed")

            try:
                await self.client.generate("test")
                return
            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
                # Server not ready yet; wait without blocking the event loop.
                await asyncio.sleep(1)
        raise RuntimeError("Health check failed")


class ContainerLauncherHandle(LauncherHandle):
    """Launcher handle backed by a Docker container."""

    def __init__(self, docker_client, container_name, port: int):
        super().__init__(port)
        self.docker_client = docker_client
        self.container_name = container_name

    def _inner_health(self) -> bool:
        container = self.docker_client.containers.get(self.container_name)
        return container.status in ["running", "created"]


class ProcessLauncherHandle(LauncherHandle):
    """Launcher handle backed by a local launcher subprocess."""

    def __init__(self, process, port: int):
        super().__init__(port)
        self.process = process

    def _inner_health(self) -> bool:
        # poll() returns None while the process is still running.
        return self.process.poll() is None


@pytest.fixture
def response_snapshot(snapshot):
    return snapshot.use_extension(ResponseComparator)


@pytest.fixture
def generous_response_snapshot(snapshot):
    return snapshot.use_extension(GenerousResponseComparator)


@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
    return snapshot.use_extension(IgnoreLogProbResponseComparator)


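# Module-scoped loop so that all async fixtures and tests within one test
# module share the same event loop.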
@pytest.fixture(scope="module")
def event_loop():
    loop = asyncio.get_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="module")
def launcher(event_loop):
    @contextlib.contextmanager
    def local_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        port = random.randint(8000, 10_000)
        master_port = random.randint(10_000, 20_000)

        shard_uds_path = (
            f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
        )

        args = [
            "text-generation-launcher",
            "--model-id",
            model_id,
            "--port",
            str(port),
            "--master-port",
            str(master_port),
            "--shard-uds-path",
            shard_uds_path,
        ]

        # Copy the environment rather than mutating os.environ, so settings
        # like LOG_LEVEL and USE_FLASH_ATTENTION do not leak into the test
        # process itself.
        env = os.environ.copy()

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))

        env["LOG_LEVEL"] = "info,text_generation_router=debug"

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"

        with tempfile.TemporaryFile("w+") as tmp:
            # We'll output stdout/stderr to a temporary file. Using a pipe
            # causes the process to block until stdout is read.
            with subprocess.Popen(
                args,
                stdout=tmp,
                stderr=subprocess.STDOUT,
                env=env,
            ) as process:
                yield ProcessLauncherHandle(process, port)

                process.terminate()
                process.wait(60)

                tmp.seek(0)
                shutil.copyfileobj(tmp, sys.stderr)

    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        port = random.randint(8000, 10_000)

        args = ["--model-id", model_id, "--env"]

        if disable_grammar_support:
            args.append("--disable-grammar-support")
        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")
        if max_input_length:
            args.append("--max-input-length")
            args.append(str(max_input_length))
        if max_batch_prefill_tokens:
            args.append("--max-batch-prefill-tokens")
            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
            args.append("--max-total-tokens")
            args.append(str(max_total_tokens))

        client = docker.from_env()

        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"

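        # Stop any container left over from a previous run under the same name.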
        try:
            container = client.containers.get(container_name)
            container.stop()
            container.wait()
        except NotFound:
            pass

        gpu_count = num_shard if num_shard is not None else 1

        env = {
            "LOG_LEVEL": "info,text_generation_router=debug",
        }

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"

        if HF_TOKEN is not None:
            env["HF_TOKEN"] = HF_TOKEN

        volumes = []
        if DOCKER_VOLUME:
            volumes = [f"{DOCKER_VOLUME}:/data"]

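        # Either pass explicit device paths through to the container (the
        # DOCKER_DEVICES / ROCm case) or request `gpu_count` GPUs through
        # Docker's device-request API.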
        if DOCKER_DEVICES:
            devices = DOCKER_DEVICES.split(",")
            visible = os.getenv("ROCR_VISIBLE_DEVICES")
            if visible:
                env["ROCR_VISIBLE_DEVICES"] = visible
            device_requests = []
        else:
            devices = []
            device_requests = [
                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
            ]

        container = client.containers.run(
            DOCKER_IMAGE,
            command=args,
            name=container_name,
            environment=env,
            auto_remove=False,
            detach=True,
            device_requests=device_requests,
            devices=devices,
            volumes=volumes,
            ports={"80/tcp": port},
            shm_size="1G",
        )

        yield ContainerLauncherHandle(client, container.name, port)

        try:
            container.stop()
            container.wait()
        except NotFound:
            pass

        container_output = container.logs().decode("utf-8")
        print(container_output, file=sys.stderr)

        container.remove()

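    # Use the containerized launcher when DOCKER_IMAGE is configured;
    # otherwise fall back to spawning a local launcher process.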
    if DOCKER_IMAGE is not None:
        return docker_launcher
    return local_launcher


@pytest.fixture(scope="module")
def generate_load():
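    # Fire off `n` identical requests concurrently and gather the responses.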
    async def generate_load_inner(
        client: AsyncClient,
        prompt: str,
        max_new_tokens: int,
        n: int,
        seed: Optional[int] = None,
        grammar: Optional[Grammar] = None,
        stop_sequences: Optional[List[str]] = None,
    ) -> List[Response]:
        futures = [
            client.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                decoder_input_details=True,
                seed=seed,
                grammar=grammar,
                stop_sequences=stop_sequences,
            )
            for _ in range(n)
        ]

        return await asyncio.gather(*futures)

    return generate_load_inner
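

# A minimal sketch of how these fixtures are typically consumed from a test
# module. The model id, fixture names, and expected token count below are
# illustrative assumptions, not values taken from this file:
#
#     @pytest.fixture(scope="module")
#     def model_handle(launcher):
#         with launcher("some-org/some-model", num_shard=1) as handle:
#             yield handle
#
#     @pytest.fixture(scope="module")
#     async def model_client(model_handle):
#         await model_handle.health(300)
#         return model_handle.client
#
#     @pytest.mark.asyncio
#     async def test_generate(model_client, response_snapshot):
#         response = await model_client.generate("Test", max_new_tokens=10)
#         assert response == response_snapshot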