# NOTE: repository-viewer banner (page-extraction residue, not part of the script):
#   "vscode:/vscode.git/clone" did not exist on "13662fd5336fc8428e130567fdb1695d664eea24"
# api_multi_server.py
#
# Launches one lightx2v.api_server process per GPU on the current node, each on
# its own port, waits for every instance to report ready, and prints the URLs.
import argparse
import subprocess
import time
import socket
import os
from typing import List, Optional, Dict
import psutil
import requests
from loguru import logger
import concurrent.futures
from dataclasses import dataclass


@dataclass
class ServerConfig:
    """Launch parameters for a single API server instance.

    ``start_server`` exports ``gpu_id`` to the child process as
    ``CUDA_VISIBLE_DEVICES`` and forwards the remaining fields as
    command-line flags to ``python -m lightx2v.api_server``.
    """

    port: int  # forwarded as --port; also used to build the service URL
    gpu_id: int  # exported to the child as CUDA_VISIBLE_DEVICES
    model_cls: str  # forwarded as --model_cls
    task: str  # forwarded as --task
    model_path: str  # forwarded as --model_path
    config_json: str  # forwarded as --config_json


def get_node_ip() -> str:
    """Return the IP address of the current node.

    A UDP socket is "connected" to a public address so the OS selects the
    outbound interface; the local end of that socket is the node's IP.
    Connecting a UDP socket only records the peer address, so no traffic
    is sent.

    Returns:
        The local IP address as a string, or ``"localhost"`` if it cannot
        be determined (the failure is logged).
    """
    try:
        # The context manager closes the socket even when connect() or
        # getsockname() raises, so the descriptor is never leaked.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
    except Exception as e:
        # Best-effort lookup: callers always get a usable host string.
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"


def is_port_in_use(port: int) -> bool:
    """Report whether something accepts TCP connections on localhost:port.

    Args:
        port: TCP port number to probe.

    Returns:
        True when a connection attempt to ``("localhost", port)`` succeeds,
        False otherwise.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns an errno-style code instead of raising;
        # 0 means the connection was accepted.
        status = probe.connect_ex(("localhost", port))
    finally:
        probe.close()
    return status == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port.

    Scans at most 1000 consecutive ports and never probes beyond 65535,
    the highest valid TCP port (probing past it would raise instead of
    returning ``None``).

    Args:
        start_port: First port number to try.

    Returns:
        The first port in the scanned range that is not in use, or
        ``None`` if every candidate is busy or the range is empty.
    """
    max_port = 65535
    # Exclusive upper bound: 1000 candidates, clipped to the valid range.
    stop = min(start_port + 1000, max_port + 1)
    for port in range(start_port, stop):
        if not is_port_in_use(port):
            return port
    return None


def _stop_server_process(process: subprocess.Popen) -> None:
    """Terminate a child process and reap it, killing it if it will not exit."""
    if process.poll() is not None:
        return  # already exited; nothing to stop
    process.terminate()
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance and wait until it reports ready.

    The server is launched as ``python -m lightx2v.api_server`` with
    ``CUDA_VISIBLE_DEVICES`` set to ``config.gpu_id``. Its status endpoint
    is then polled once per second, up to 600 times.

    Args:
        config: Port, GPU and model settings for this instance.

    Returns:
        ``(process, base_url)`` once the status endpoint answers with
        HTTP 200, or ``None`` if the process exits early, startup times
        out, or launching fails. On every failure path the child process
        (if any) is stopped and the error is logged.
    """
    process = None
    try:
        # Pin the child to a single GPU.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)

        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )

        node_ip = get_node_ip()
        base_url = f"http://{node_ip}:{config.port}"
        service_url = f"{base_url}/v1/local/video/generate/service_status"

        # Check once per second, up to 600 times.
        for _ in range(600):
            # Stop polling as soon as the child has died; otherwise a
            # crash at startup would stall this worker for the full timeout.
            exit_code = process.poll()
            if exit_code is not None:
                logger.error(f"Server process exited during startup: port={config.port}, gpu={config.gpu_id}, exit_code={exit_code}")
                return None
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, base_url
            except (requests.RequestException, ConnectionError):
                # Expected while the server is still booting; retry.
                pass
            time.sleep(1)

        # Timed out: stop the child and reap it so it does not linger.
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        _stop_server_process(process)
        return None

    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        # Do not leave a half-started child behind.
        if process is not None:
            _stop_server_process(process)
        return None


def main():
    """Launch one API server per GPU on this node and wait for them.

    Parses the command line, assigns each GPU index in
    ``range(--num_gpus)`` a free port at or after ``--start_port``,
    starts all servers in parallel, prints the URLs of those that came
    up, then blocks until the server processes exit. On Ctrl-C the
    servers are terminated and reaped.

    Raises:
        SystemExit: With status 1 when no server configuration could be
            built (``--num_gpus`` is 0 or no free port was found).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()

    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port

    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue

        config = ServerConfig(port=port, gpu_id=gpu, model_cls=args.model_cls, task=args.task, model_path=args.model_path, config_json=args.config_json)
        server_configs.append(config)
        # Next GPU starts searching just past the port we reserved.
        current_port = port + 1

    # ThreadPoolExecutor rejects max_workers=0, so bail out with a clear
    # message instead of an opaque ValueError.
    if not server_configs:
        logger.error("No server configurations could be created; nothing to start")
        raise SystemExit(1)

    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")

    # Print all server URLs
    print("\nAll server URLs:")
    for url in urls:
        print(url)

    # Print node information
    node_ip = get_node_ip()
    print(f"\nCurrent node IP: {node_ip}")
    print(f"Number of servers started: {len(urls)}")

    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        # Signal every server first so they shut down concurrently...
        for process in processes:
            process.terminate()
        # ...then reap each one, escalating to kill if it ignores SIGTERM.
        for process in processes:
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()

# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()