import json
import os
import tempfile
import unittest

import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

class TestInputEmbeds(CustomTestCase):
    """End-to-end tests for the /generate and /generate_from_file endpoints,
    exercising both plain-text prompts and precomputed input embeddings.

    A single server process is launched for the whole class and torn down
    afterwards; a HuggingFace reference model supplies the embeddings.
    """

    # Timeout (seconds) for each individual HTTP request to the server.
    REQUEST_TIMEOUT = 30

    @classmethod
    def setUpClass(cls):
        """Launch the server once and load the tokenizer/reference model."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--disable-radix", "--cuda-graph-max-bs", 4],
        )
        # Prompts shared by every test method.
        cls.texts = [
            "The capital of France is",
            "What is the best time of year to visit Japan for cherry blossoms?",
        ]

    def generate_input_embeddings(self, text):
        """Generate input embeddings for a given text.

        Returns a nested Python list (token x hidden dim after squeeze),
        JSON-serializable for use in API payloads.
        """
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        embeddings = self.ref_model.get_input_embeddings()(input_ids)
        return embeddings.squeeze().tolist()  # Convert tensor to a list for API use

    @staticmethod
    def _handle_response(response):
        """Map an HTTP response to a dict: the JSON body on 200, else an error dict."""
        if response.status_code == 200:
            return response.json()
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }

    def _sampling_payload(self, **fields):
        """Build a /generate payload with the shared greedy sampling params."""
        return {
            "model": self.model,
            "sampling_params": {"temperature": 0, "max_new_tokens": 50},
            **fields,
        }

    def send_request(self, payload):
        """Send a POST request to the /generate endpoint and return the response."""
        response = requests.post(
            self.base_url + "/generate",
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        return self._handle_response(response)

    def send_file_request(self, file_path):
        """Send a POST request to the /generate_from_file endpoint with a file."""
        with open(file_path, "rb") as f:
            response = requests.post(
                self.base_url + "/generate_from_file",
                files={"file": f},
                timeout=self.REQUEST_TIMEOUT,
            )
        return self._handle_response(response)

    def test_text_based_response(self):
        """The server must answer text-based payloads without an HTTP error."""
        for text in self.texts:
            response = self.send_request(self._sampling_payload(text=text))
            self.assertNotIn("error", response, f"/generate failed for text input: {text}")
            print(
                f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )

    def test_embedding_based_response(self):
        """The server must answer embedding-based payloads without an HTTP error."""
        for text in self.texts:
            embeddings = self.generate_input_embeddings(text)
            response = self.send_request(self._sampling_payload(input_embeds=embeddings))
            self.assertNotIn(
                "error", response, f"/generate failed for embeddings of: {text}"
            )
            print(
                f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )

    def test_compare_text_vs_embedding(self):
        """Compare responses for text-based and embedding-based inputs."""
        for text in self.texts:
            text_response = self.send_request(self._sampling_payload(text=text))
            embeddings = self.generate_input_embeddings(text)
            embed_response = self.send_request(
                self._sampling_payload(input_embeds=embeddings)
            )
            self.assertNotIn("error", text_response)
            self.assertNotIn("error", embed_response)
            print(
                f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
            )
            print(
                f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
            )

            # Exact-match comparison is flaky (tiny numerical differences can
            # change greedy decoding), so we skip this temporarily.
            # self.assertEqual(text_response["text"], embed_response["text"])

    def test_generate_from_file(self):
        """Test the /generate_from_file endpoint using tokenized embeddings."""
        for text in self.texts:
            embeddings = self.generate_input_embeddings(text)
            # Write the embeddings to a temp JSON file the endpoint can consume.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", delete=False
            ) as tmp_file:
                json.dump(embeddings, tmp_file)
                tmp_file_path = tmp_file.name

            try:
                response = self.send_file_request(tmp_file_path)
                self.assertNotIn(
                    "error", response, f"/generate_from_file failed for: {text}"
                )
                print(
                    f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(response, indent=2)}\n{'-' * 80}"
                )
            finally:
                # Ensure the temporary file is deleted even if the request fails.
                os.remove(tmp_file_path)

    @classmethod
    def tearDownClass(cls):
        """Shut down the launched server process tree."""
        kill_process_tree(cls.process.pid)


# Allow running this test module directly (e.g. `python test_input_embeddings.py`).
if __name__ == "__main__":
    unittest.main()