import asyncio
import os
import sys
from typing import List, Optional
from unittest.mock import patch

import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
                                                     get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
    RayTokenizerGroupPool)

from ..conftest import get_tokenizer_pool_config
from ..utils import models_path_prefix


class CustomTokenizerGroup(TokenizerGroup):
    """``TokenizerGroup`` subclass that counts its ``encode`` calls.

    ``test_tokenizer_group`` passes this class as a tokenizer group type and
    then inspects ``_i`` to confirm this subclass was the one actually used.
    """

    def __init__(self, *init_args, **init_kwargs):
        super().__init__(*init_args, **init_kwargs)
        # Running total of encode() invocations on this instance.
        self._i = 0

    def encode(self, *encode_args, **encode_kwargs):
        """Record the call, then defer to ``TokenizerGroup.encode``."""
        self._i = self._i + 1
        return super().encode(*encode_args, **encode_kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type",
                         [None, "ray", CustomTokenizerGroup])
async def test_tokenizer_group(tokenizer_group_type):
    """Check that each tokenizer group type matches a plain HF tokenizer.

    Parametrized over ``None`` (presumably the default non-pooled group —
    confirm in ``get_tokenizer_pool_config``), the ``"ray"`` pool, and the
    ``CustomTokenizerGroup`` subclass. Exercises both the sync and async
    ``encode`` and LoRA-tokenizer accessors.
    """
    # Resolve the model location once so the reference tokenizer and the
    # tokenizer group are guaranteed to load from the same path.
    model_path = os.path.join(models_path_prefix, "gpt2")
    reference_tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer_group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
        request_id="request_id", prompt="prompt", lora_request=None)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer_group.encode_async(
            request_id="request_id", prompt="prompt", lora_request=None)
    assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)

    if tokenizer_group_type is CustomTokenizerGroup:
        # CustomTokenizerGroup counts encode() calls in ``_i``; a non-zero
        # count shows the custom class was really instantiated and used.
        assert tokenizer_group._i > 0


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    """Check that a tokenizer pool serves more concurrent requests than it
    has workers and that every result matches the reference tokenizer,
    in request order."""
    model_path = os.path.join(models_path_prefix, "gpt2")
    reference_tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer_group_pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    # Send more requests than the pool has workers (5x pool_size) and check
    # that all of them are processed correctly.
    num_requests = tokenizer_group_pool.pool_size * 5
    requests = [
        tokenizer_group_pool.encode_async(request_id=str(i),
                                          prompt=f"prompt {i}",
                                          lora_request=None)
        for i in range(num_requests)
    ]
    # asyncio.gather preserves input order, so results line up with i.
    results = await asyncio.gather(*requests)
    expected_results = [
        reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
    ]
    assert results == expected_results


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Test that env vars from caller process are propagated to
    tokenizer Ray actors.

    A pool created while ``MY_ENV_VAR`` is not set must fail the worker-side
    check. A pool created inside ``patch.dict(os.environ, ...)`` must pass it.
    """
    env_var = "MY_ENV_VAR"
    model_path = os.path.join(models_path_prefix, "gpt2")

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            # This class is the pool's ``_worker_cls``, so the assertion
            # checks the environment seen by the tokenizer actor.
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    # Variable not set: the worker-side assertion should surface here.
    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None)
    with pytest.raises(AssertionError):
        tokenizer_pool.ping()

    # Variable set in the caller: a freshly created pool should inherit it.
    with patch.dict(os.environ, {env_var: "1"}):
        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id=model_path,
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)
        tokenizer_pool.ping()


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that Ray tokenizer pool group can recover from failures and
    if that's not possible, mark itself as unhealthy."""
    model_path = os.path.join(models_path_prefix, "gpt2")

    class FailingTokenizerGroup(TokenizerGroup):
        """Tokenizer worker that exits on selected ``encode`` calls."""

        def __init__(self,
                     *args,
                     fail_at: Optional[List[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            # 1-based count of encode() calls handled by this worker.
            self.i = 0
            # Call numbers on which encode() exits instead of encoding.
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                # Exit the worker to simulate a crashed tokenizer actor.
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = FailingTokenizerGroup

    # Scenario 1: the worker fails once, then recovers.
    # Fail at the first encode() call.
    fail_at = [1]
    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Mutate fail_at in place so the worker never fails again. The same list
    # object is re-read when the actor is re-initialized.
    fail_at[0] = 1000

    # We should recover successfully.
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)

    # Check that the pool kept its size but replaced the actor.
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Scenario 2: the replacement worker fails too, so the pool is unhealthy.
    # Fail at the first encode() call, and keep failing after re-init.
    fail_at = [1]
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)

    # We should fail after re-initialization.
    with pytest.raises(RuntimeError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt",
                                                lora_request=None)

    # check_health should raise the same thing
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Scenario 3: non-ActorDiedErrors must still propagate correctly and must
    # not cause a re-initialization.
    fail_at = []
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=2,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error (max_input_length=2 above).
    with pytest.raises(ValueError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt" * 100,
                                                lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors