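"""
Test launching the server with --skip-tokenizer-init: prompts are tokenized
on the client and sent as input_ids, and responses carry raw token ids
instead of decoded text.
"""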
import json
import unittest

import requests
from transformers import AutoTokenizer

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

_server_process = None
_base_url = None
_tokenizer = None


def setUpModule():
    """
    Launch the server once before all tests and initialize the tokenizer.
    """
    global _server_process, _base_url, _tokenizer
    _server_process = popen_launch_server(
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        DEFAULT_URL_FOR_TEST,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=["--skip-tokenizer-init"],
    )
    _base_url = DEFAULT_URL_FOR_TEST

    _tokenizer = AutoTokenizer.from_pretrained(
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
    )
    print(">>> setUpModule: Server launched, tokenizer ready")


def tearDownModule():
    """
    Terminate the server once after all tests have completed.
    """
    global _server_process
    if _server_process is not None:
        kill_process_tree(_server_process.pid)
        _server_process = None
    print(">>> tearDownModule: Server terminated")


class TestSkipTokenizerInit(unittest.TestCase):
    def run_decode(
        self,
        prompt_text="The capital of France is",
        max_new_tokens=32,
        return_logprob=False,
        top_logprobs_num=0,
        n=1,
    ):
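        # The server runs with --skip-tokenizer-init, so the prompt is
        # tokenized here on the client and sent as input_ids instead of text.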
        input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][
            0
        ].tolist()

        response = requests.post(
            _base_url + "/generate",
            json={
                "input_ids": input_ids,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": max_new_tokens,
                    "n": n,
                    "stop_token_ids": [_tokenizer.eos_token_id],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret, indent=2))

        def assert_one_item(item):
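            # Check one completion: eos match on "stop" finishes; token counts
            # and (when requested) logprob lengths on "length" finishes.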
            if item["meta_info"]["finish_reason"]["type"] == "stop":
                self.assertEqual(
                    item["meta_info"]["finish_reason"]["matched"],
                    _tokenizer.eos_token_id,
                )
            elif item["meta_info"]["finish_reason"]["type"] == "length":
                self.assertEqual(
                    len(item["token_ids"]), item["meta_info"]["completion_tokens"]
                )
                self.assertEqual(len(item["token_ids"]), max_new_tokens)
                self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))

                if return_logprob:
                    self.assertEqual(
                        len(item["meta_info"]["input_token_logprobs"]),
                        len(input_ids),
                        f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
                    )
                    self.assertEqual(
                        len(item["meta_info"]["output_token_logprobs"]),
                        max_new_tokens,
                    )

        # Determine whether to assert a single item or multiple items based on n
        if n == 1:
            assert_one_item(ret)
        else:
            self.assertEqual(len(ret), n)
            for i in range(n):
                assert_one_item(ret[i])

        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_logprob(self):
        for top_logprobs_num in [0, 3]:
            self.run_decode(return_logprob=True, top_logprobs_num=top_logprobs_num)

    def test_eos_behavior(self):
        self.run_decode(max_new_tokens=256)


if __name__ == "__main__":
    unittest.main()