# util.py
"""
Standalone utilities for e2e_grpc tests.

This module provides all necessary utilities without depending on the sglang Python package.
Extracted and adapted from:
- sglang.srt.utils.kill_process_tree
- sglang.srt.utils.hf_transformers_utils.get_tokenizer
- sglang.test.test_utils (constants and CustomTestCase)
"""

11
import logging
12
13
14
15
16
17
18
19
20
import os
import signal
import threading
import unittest
from pathlib import Path
from typing import Optional, Union

import psutil

# Module-level logger used by the tokenizer and test helpers below.
logger = logging.getLogger(__name__)


try:
    from transformers import (
        AutoTokenizer,
        PreTrainedTokenizer,
        PreTrainedTokenizerBase,
        PreTrainedTokenizerFast,
    )
except ImportError as e:
    # Fail fast with an actionable install hint; chain the original error so
    # the underlying cause (e.g. a broken sub-dependency) stays visible.
    raise ImportError(
        "transformers is required for tokenizer utilities. "
        "Install with: pip install transformers"
    ) from e


# ============================================================================
# Constants
# ============================================================================

# Server launch / networking defaults shared by the e2e tests.
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 20000
# The default test URL listens 1000 ports above the test-runner base port.
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"

# File paths for test stdout / stderr output.
STDOUT_FILENAME = "/tmp/sglang_test_stdout.txt"
STDERR_FILENAME = "/tmp/sglang_test_stderr.txt"

# Optional base directory for model checkpoints, read from the environment.
# Empty (the default) means model identifiers are used verbatim, i.e. as
# HuggingFace identifiers. Point it at a directory such as
# "/home/ubuntu/models" to resolve every model under that directory instead.
ROUTER_LOCAL_MODEL_PATH = os.environ.get("ROUTER_LOCAL_MODEL_PATH", "")


def _get_model_path(model_identifier: str) -> str:
    """
    Resolve a model identifier against ROUTER_LOCAL_MODEL_PATH.

    Returns ``os.path.join(ROUTER_LOCAL_MODEL_PATH, model_identifier)`` when a
    base directory is configured (non-empty); otherwise returns the identifier
    unchanged so it can be fetched from HuggingFace.
    """
    base_dir = ROUTER_LOCAL_MODEL_PATH
    if not base_dir:
        return model_identifier
    return os.path.join(base_dir, model_identifier)


# ---- Models exercised by the e2e_grpc tests --------------------------------
# Each value is a HuggingFace identifier, or a local path when
# ROUTER_LOCAL_MODEL_PATH is set.

# Primary test model: Llama 3.1 8B Instruct.
DEFAULT_MODEL_PATH = _get_model_path("meta-llama/Llama-3.1-8B-Instruct")

# Small model used by the function-calling tests.
DEFAULT_SMALL_MODEL_PATH = _get_model_path("meta-llama/Llama-3.2-1B-Instruct")

# Reasoning model.
DEFAULT_REASONING_MODEL_PATH = _get_model_path(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
)

# Model with thinking mode enabled.
DEFAULT_ENABLE_THINKING_MODEL_PATH = _get_model_path("Qwen/Qwen3-30B-A3B")

# Function-calling models (Qwen and Mistral families).
DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH = _get_model_path("Qwen/Qwen2.5-7B-Instruct")
DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH = _get_model_path(
    "mistralai/Mistral-7B-Instruct-v0.3"
)


# ============================================================================
# Process Management
# ============================================================================


def kill_process_tree(
    parent_pid: Optional[int],
    include_parent: bool = True,
    skip_pid: Optional[int] = None,
):
    """
    Kill the process and all its child processes (recursively).

    Each process is terminated with ``psutil.Process.kill()``. Processes that
    have already exited are ignored.

    Args:
        parent_pid: PID of the root process. ``None`` means the current
            process; in that case only its descendants are killed and the
            current process itself is always spared.
        include_parent: Whether to kill the root process itself as well.
        skip_pid: Optional PID of a descendant to leave alive.

    Side effect: when called from the main thread, the SIGCHLD handler is
    reset to the default disposition.
    """
    # Remove sigchld handler to avoid spammy logs
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        itself = psutil.Process(parent_pid)
        # children() may also raise NoSuchProcess if the root exited (or its
        # PID was reused) after the lookup above; treat that as "already gone"
        # instead of crashing the caller's cleanup path.
        children = itself.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    for child in children:
        if child.pid == skip_pid:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # Exited between enumeration and kill; nothing left to do.
            pass

    if include_parent:
        try:
            itself.kill()
        except psutil.NoSuchProcess:
            pass


# ============================================================================
# Tokenizer Utilities
# ============================================================================


def check_gguf_file(model_path: str) -> bool:
    """
    Return True when *model_path* is a string ending in ``.gguf``.

    Only the suffix is inspected; the file is never opened. Any non-string
    argument yields False.
    """
    return isinstance(model_path, str) and model_path.endswith(".gguf")


def is_remote_url(path: str) -> bool:
    """
    Return True when *path* is a string starting with ``http://`` or ``https://``.

    The prefix check is case-sensitive. Any non-string argument yields False.
    """
    if isinstance(path, str):
        return path.startswith(("http://", "https://"))
    return False


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """
    Gets a tokenizer for the given model name via Huggingface.

    Args:
        tokenizer_name: Name or path of the tokenizer. A path ending in
            ``.gguf`` is loaded from its parent directory with
            ``gguf_file=`` set to the original path.
        *args: Extra positional arguments forwarded to
            ``AutoTokenizer.from_pretrained``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; any other value
            leaves the choice to ``kwargs`` / transformers defaults.
        trust_remote_code: Whether to trust remote code.
        tokenizer_revision: Specific revision (branch, tag or commit) to load.
            Forwarded as ``revision=`` unless ``kwargs`` already sets one.
        **kwargs: Additional arguments passed to AutoTokenizer.from_pretrained.

    Returns:
        Loaded tokenizer instance.

    Raises:
        ValueError: If ``tokenizer_mode == "slow"`` but ``use_fast=True``
            was requested.
        RuntimeError: If transformers raises ``TypeError`` while loading.
    """
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # Handle special model name mapping
    if tokenizer_name == "mistralai/Devstral-Small-2505":
        tokenizer_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

    if check_gguf_file(tokenizer_name):
        kwargs["gguf_file"] = tokenizer_name
        tokenizer_name = Path(tokenizer_name).parent

    # transformers selects the hub revision via ``revision=``; a
    # ``tokenizer_revision=`` kwarg is not a download argument and would not
    # pin anything. An explicit ``revision`` in kwargs takes precedence.
    if tokenizer_revision is not None:
        kwargs.setdefault("revision", tokenizer_revision)

    # Note: Removed remote URL handling and local directory download
    # as they depend on sglang-specific utilities

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except TypeError as e:
        # Handle specific errors
        err_msg = (
            "Failed to load the tokenizer. If you are running a model with "
            "a custom tokenizer, please set the --trust-remote-code flag."
        )
        raise RuntimeError(err_msg) from e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a performance "
            "degradation. Consider using a fast tokenizer instead."
        )

    return tokenizer


def get_tokenizer_from_processor(processor):
    """
    Return the tokenizer carried by *processor*.

    An object that already is a tokenizer (any ``PreTrainedTokenizerBase``
    instance) is returned unchanged. For anything else, its ``tokenizer``
    attribute is returned.
    """
    already_tokenizer = isinstance(processor, PreTrainedTokenizerBase)
    return processor if already_tokenizer else processor.tokenizer


# ============================================================================
# Test Utilities
# ============================================================================


class CustomTestCase(unittest.TestCase):
    """
    Custom test case base class with retry support.

    A failing test method is re-run up to ``SGLANG_TEST_MAX_RETRY`` extra
    times (environment variable, default ``"0"`` = no retry). The retry hooks
    unittest's ``_callTestMethod``, so ``setUp``/``tearDown`` are not re-run
    between attempts.
    """

    def _callTestMethod(self, method):
        """
        Run *method*, retrying on failure as configured by SGLANG_TEST_MAX_RETRY.

        Raises:
            unittest.SkipTest: Re-raised immediately; skips are never retried.
            Exception: Whatever the final attempt raised.
        """
        # Clamp at 0: a negative value would make range() below empty, so the
        # test method would never run and the test would silently "pass".
        max_retry = max(0, int(os.environ.get("SGLANG_TEST_MAX_RETRY", "0")))

        for attempt in range(max_retry + 1):
            try:
                return super()._callTestMethod(method)
            except unittest.SkipTest:
                # A skip is a deliberate outcome, not a flaky failure.
                raise
            except Exception:
                if attempt >= max_retry:
                    # Out of attempts: surface the last failure.
                    raise
                logger.info(
                    "Test failed on attempt %d/%d, retrying...",
                    attempt + 1,
                    max_retry + 1,
                )