# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
import sys
from typing import Optional
from unittest.mock import patch

import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
                                                     get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
    RayTokenizerGroupPool)

from ..conftest import get_tokenizer_pool_config
from ..utils import models_path_prefix


class CustomTokenizerGroup(TokenizerGroup):
    """``TokenizerGroup`` subclass that counts its ``encode`` calls.

    ``test_tokenizer_group`` passes this class as a tokenizer group type and
    then checks the counter to confirm that the custom class actually served
    the encode requests.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of times encode() has been entered on this instance.
        self._i = 0

    def encode(self, *args, **kwargs):
        # Bump the counter first, then delegate unchanged to the base class.
        self._i = self._i + 1
        base_encode = super().encode
        return base_encode(*args, **kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type",
                         [None, "ray", CustomTokenizerGroup])
async def test_tokenizer_group(tokenizer_group_type):
    """Check that each tokenizer group type encodes like the HF tokenizer.

    Covers the default group (``None``), the Ray pool (``"ray"``) and a
    user-supplied ``TokenizerGroup`` subclass. It exercises both the sync and
    async ``encode`` and ``get_lora_tokenizer`` accessors.
    """
    model_path = os.path.join(models_path_prefix, "gpt2")
    reference_tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer_group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    expected_ids = reference_tokenizer.encode("prompt")
    assert expected_ids == tokenizer_group.encode(prompt="prompt",
                                                  lora_request=None)
    assert expected_ids == await tokenizer_group.encode_async(
        prompt="prompt", lora_request=None)

    assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)

    if tokenizer_group_type is CustomTokenizerGroup:
        # The custom subclass must actually have handled the encode calls.
        assert tokenizer_group._i > 0


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    """Check that a tokenizer pool serves more concurrent requests than its
    pool size, and that result ``i`` matches the reference encoding of
    prompt ``i``."""
    model_path = os.path.join(models_path_prefix, "gpt2")
    reference_tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer_group_pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id=model_path,
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    # Send multiple requests to the tokenizer group pool
    # (more than the pool size)
    # and check that all requests are processed correctly.
    num_requests = tokenizer_group_pool.pool_size * 5
    requests = [
        tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
                                          lora_request=None)
        for i in range(num_requests)
    ]
    # asyncio.gather returns results in the order the awaitables were given.
    results = await asyncio.gather(*requests)
    expected_results = [
        reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
    ]
    assert results == expected_results


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Test that env vars from caller process are propagated to
    tokenizer Ray actors."""
    env_var = "MY_ENV_VAR"
    model_path = os.path.join(models_path_prefix, "gpt2")

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            # Used as the pool's worker class (see _worker_cls below), so
            # this checks the environment the worker sees.
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    # Without the variable in the caller's environment, the worker-side check
    # must fail. Explicitly unset it so the result does not depend on the
    # environment the test runner was launched with; patch.dict restores
    # os.environ on exit.
    with patch.dict(os.environ):
        os.environ.pop(env_var, None)
        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id=model_path,
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)
        with pytest.raises(AssertionError):
            tokenizer_pool.ping()

    # A pool created while the variable is set must see it in its workers.
    with patch.dict(os.environ, {env_var: "1"}):
        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id=model_path,
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)
        tokenizer_pool.ping()


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that Ray tokenizer pool group can recover from failures and
    if that's not possible, mark itself as unhealthy."""
    model_path = os.path.join(models_path_prefix, "gpt2")

    class FailingTokenizerGroup(TokenizerGroup):
        """TokenizerGroup that exits its process on selected encode() calls."""

        def __init__(self,
                     *args,
                     fail_at: Optional[list[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            # 1-based count of encode() calls made on this instance.
            self.i = 0
            # Call numbers on which encode() terminates the process.
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                # sys.exit raises SystemExit rather than an ordinary error;
                # intended to kill the hosting Ray actor.
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = FailingTokenizerGroup

    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)

    def make_pool(fail_at: list[int],
                  max_input_length: Optional[int] = None):
        # Build a pool of FailingTokenizerGroup workers. The same ``fail_at``
        # list object is handed to from_config, so later in-place mutations
        # are what a re-initialized actor will see.
        return FailingRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id=model_path,
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=max_input_length,
            fail_at=fail_at)

    # Scenario 1: the first encode kills the actor, but the pool recovers.
    # Fail at first iteration
    fail_at = [1]
    tokenizer_group_pool = make_pool(fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail at to not fail at all (will be re-read when actor is
    # re-initialized).
    fail_at[0] = 1000

    # We should recover successfully.
    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)

    # Check that we have a new actor
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Scenario 2: the re-initialized actor also fails, so recovery is not
    # possible and the pool reports itself unhealthy.
    # Fail at first iteration
    fail_at = [1]
    tokenizer_group_pool = make_pool(fail_at)

    # We should fail after re-initialization.
    with pytest.raises(RuntimeError):
        await tokenizer_group_pool.encode_async(prompt="prompt",
                                                lora_request=None)

    # check_health should raise the same thing
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Scenario 3: ordinary errors must propagate without replacing actors.
    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
    # cause a re-initialization.
    fail_at = []
    tokenizer_group_pool = make_pool(fail_at, max_input_length=2)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error
    with pytest.raises(ValueError):
        await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
                                                lora_request=None)
    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)

    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors