# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path

import pytest

from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.registry import (
    TokenizerRegistry,
    get_tokenizer,
    resolve_tokenizer_args,
)


class TestTokenizer(TokenizerLike):
    """Minimal ``TokenizerLike`` stub used to exercise the tokenizer registry.

    The instance only remembers the ``path_or_repo_id`` it was built from and
    reports fixed special-token ids (bos=0, eos=1, pad=2).
    """

    # The name starts with "Test" and the class defines ``__init__``, so
    # pytest would try to collect it and emit a collection warning. Opt out.
    __test__ = False

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "TestTokenizer":
        """Build a stub tokenizer for ``path_or_repo_id``.

        Every argument other than ``path_or_repo_id`` is accepted and ignored;
        nothing is downloaded or read from disk.
        """
        # Construct through ``cls`` so a subclass gets an instance of itself.
        return cls(path_or_repo_id)  # type: ignore

    def __init__(self, path_or_repo_id: str | Path) -> None:
        """Store ``path_or_repo_id`` so tests can check what was loaded."""
        super().__init__()

        self.path_or_repo_id = path_or_repo_id

    @property
    def bos_token_id(self) -> int:
        """Fixed beginning-of-sequence token id."""
        return 0

    @property
    def eos_token_id(self) -> int:
        """Fixed end-of-sequence token id."""
        return 1

    @property
    def pad_token_id(self) -> int:
        """Fixed padding token id."""
        return 2

    @property
    def is_fast(self) -> bool:
        """Always reports itself as a fast tokenizer."""
        return True

@pytest.mark.parametrize("runner_type", ["generate", "pooling"])
def test_resolve_tokenizer_args_idempotent(runner_type):
    """Re-resolving already resolved tokenizer arguments must not change them.

    The name, positional args and keyword args returned by the first call are
    fed back into ``resolve_tokenizer_args``; the second result has to equal
    the first.
    """
    first_pass = resolve_tokenizer_args(
        "facebook/opt-125m",
        runner_type=runner_type,
    )
    tokenizer_mode, tokenizer_name, args, kwargs = first_pass

    second_pass = resolve_tokenizer_args(tokenizer_name, *args, **kwargs)

    assert (tokenizer_mode, tokenizer_name, args, kwargs) == second_pass


def test_customized_tokenizer():
    """Register ``TestTokenizer`` as a custom mode and load it both ways.

    The stub is looked up first through ``TokenizerRegistry.load_tokenizer``
    and then through ``get_tokenizer``; both must return a ``TestTokenizer``
    built from ``"abc"`` with the stub's fixed special-token ids.
    """

    def _assert_is_stub(candidate):
        # Same checks, in the same order, for either loading path.
        assert isinstance(candidate, TestTokenizer)
        assert candidate.path_or_repo_id == "abc"
        assert candidate.bos_token_id == 0
        assert candidate.eos_token_id == 1
        assert candidate.pad_token_id == 2

    TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)

    # Direct registry lookup by the registered mode name.
    _assert_is_stub(TokenizerRegistry.load_tokenizer("test_tokenizer", "abc"))

    # Same tokenizer requested through get_tokenizer via ``tokenizer_mode``.
    _assert_is_stub(get_tokenizer("abc", tokenizer_mode="test_tokenizer"))