# test_engine_token_ids.py
import unittest

from transformers import AutoTokenizer

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST


class TestEngineTokenIds(unittest.TestCase):
    """Check that an ``sgl.Engine`` built with ``return_token_ids=True`` reports token ids.

    Each result returned by ``Engine.generate`` is expected to carry
    ``input_ids`` and ``output_ids``. Decoding them with the model's
    tokenizer should recover the prompt and the generated ``text``
    (compared by containment in either direction).
    """

    def test_token_ids_in_generate(self):
        """Decoded ``input_ids``/``output_ids`` must agree with prompt/``text``."""
        # Load the tokenizer before starting the engine so that a tokenizer
        # failure cannot leave a running engine behind.
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

        llm = sgl.Engine(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, return_token_ids=True
        )
        # Cleanups run even when an assertion below fails, so the engine is
        # always shut down (the original trailing shutdown() call was skipped
        # on failure).
        self.addCleanup(llm.shutdown)

        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        sampling_params = {"temperature": 0, "top_p": 0.95}

        outputs = llm.generate(prompts, sampling_params)

        # zip() below would silently ignore missing results; require exactly
        # one output per prompt.
        self.assertEqual(len(outputs), len(prompts))

        for prompt, output in zip(prompts, outputs):
            with self.subTest(prompt=prompt):
                # Containment (either direction) rather than equality —
                # presumably because decode() need not round-trip the text
                # exactly; kept as in the original check.
                decoded_input = tokenizer.decode(
                    output["input_ids"], skip_special_tokens=True
                )
                self.assertTrue(
                    decoded_input in prompt or prompt in decoded_input,
                    f"Decode input: {decoded_input} mismatch for: {prompt}",
                )

                decoded_output = tokenizer.decode(
                    output["output_ids"], skip_special_tokens=True
                )
                self.assertTrue(
                    decoded_output in output["text"]
                    or output["text"] in decoded_output,
                    f"Decode output: {decoded_output} mismatch for: {output['text']}",
                )


if __name__ == "__main__":
    # Running this file directly executes the test cases defined above.
    unittest.main(module="__main__")