conftest.py
import subprocess
import contextlib
import pytest
import asyncio
import os
import docker

from docker.errors import NotFound
from typing import Optional, List
from syrupy.filters import props

from text_generation import AsyncClient
from text_generation.types import Response

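# Configuration comes from the environment: DOCKER_IMAGE switches the launcher
# fixture to docker mode, HUGGING_FACE_HUB_TOKEN is forwarded to the container
# for gated models, and DOCKER_VOLUME is bind-mounted as the model cache.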
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")


@pytest.fixture
def snapshot_test(snapshot):
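    # Compare a value against the stored syrupy snapshot, excluding the
    # "logprob" fields, which vary across hardware.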
    return lambda value: value == snapshot(exclude=props("logprob"))


@pytest.fixture(scope="module")
def event_loop():
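    # Module-scoped loop so async fixtures and tests within a module share one
    # event loop (overrides pytest-asyncio's default function-scoped loop).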
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="module")
def launcher(event_loop):
    @contextlib.contextmanager
    def local_launcher(
        model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
    ):
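        # Run `text-generation-launcher` as a local subprocess on fixed ports;
        # shards listen on a unix socket under /tmp.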
        port = 9999
        master_port = 19999

        shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"

        args = [
            "text-generation-launcher",
            "--model-id",
            model_id,
            "--port",
            str(port),
            "--master-port",
            str(master_port),
            "--shard-uds-path",
            shard_uds_path,
        ]

        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.extend(["--quantize", quantize])

        with subprocess.Popen(
            args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        ) as process:
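            # NOTE: the launcher may still be loading the model at this point;
            # callers are expected to wait until the server answers requests.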
            yield AsyncClient(f"http://localhost:{port}")

            process.terminate()
            try:
                process.wait(timeout=60)
            except subprocess.TimeoutExpired:
                # The launcher did not exit cleanly; make sure it is gone.
                process.kill()
                process.wait()

            launcher_output = process.stdout.read().decode("utf-8")
            print(launcher_output)

            process.stdout.close()
            process.stderr.close()

    @contextlib.contextmanager
    def docker_launcher(
        model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
    ):
        port = 9999

        args = ["--model-id", model_id, "--env"]

        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.extend(["--quantize", quantize])

        client = docker.from_env()

        container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"

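        # Stop any stale container left over from a previous run with the same name.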
        try:
            container = client.containers.get(container_name)
            container.stop()
            container.wait()
        except NotFound:
            pass

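        # One GPU per shard; a single GPU when the model is not sharded.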
        gpu_count = num_shard if num_shard is not None else 1

        env = {}
        if HUGGING_FACE_HUB_TOKEN is not None:
            env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN

        volumes = []
        if DOCKER_VOLUME:
            volumes = [f"{DOCKER_VOLUME}:/data"]

        container = client.containers.run(
            DOCKER_IMAGE,
            command=args,
            name=container_name,
            environment=env,
            auto_remove=True,
            detach=True,
            device_requests=[
                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
            ],
            volumes=volumes,
            ports={"80/tcp": port},
        )

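        # As above, the model may still be loading when the client is yielded.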
        yield AsyncClient(f"http://localhost:{port}")

        # Read the logs before stopping: with auto_remove=True the container is
        # deleted as soon as it exits, after which logs() would fail.
        container_output = container.logs().decode("utf-8")
        print(container_output)

        container.stop()

    if DOCKER_IMAGE is not None:
        return docker_launcher
    return local_launcher


@pytest.fixture(scope="module")
def generate_load():
    async def generate_load_inner(
        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
    ) -> List[Response]:
        futures = [
            client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
        ]

        # Return the Response objects themselves, matching the annotated
        # List[Response] return type.
        return await asyncio.gather(*futures)

    return generate_load_inner
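

# A minimal usage sketch for these fixtures (a hypothetical test module, not
# part of this file); the model id and prompt are assumptions for illustration:
#
#     @pytest.fixture(scope="module")
#     def bloom_560m(launcher):
#         with launcher("bigscience/bloom-560m", num_shard=1) as client:
#             yield client
#
#     @pytest.mark.asyncio
#     async def test_bloom_560m(bloom_560m, snapshot_test):
#         response = await bloom_560m.generate("Test request", max_new_tokens=10)
#         assert snapshot_test(response)
#
#     @pytest.mark.asyncio
#     async def test_bloom_560m_load(bloom_560m, generate_load, snapshot_test):
#         responses = await generate_load(
#             bloom_560m, "Test request", max_new_tokens=10, n=4
#         )
#         assert snapshot_test(responses)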