post_multi_servers.py 4.67 KB
Newer Older
gaclove's avatar
gaclove committed
1
2
import base64
import os
3
import threading
PengGao's avatar
PengGao committed
4
import time
5
from typing import Any
PengGao's avatar
PengGao committed
6
7
8

import requests
from loguru import logger
9
from tqdm import tqdm
10
11


gaclove's avatar
gaclove committed
12
13
14
15
16
17
18
def image_to_base64(image_path):
    """Return the bytes of the file at ``image_path`` as a base64 text string.

    The file is read in binary mode; the base64-encoded bytes are decoded as
    UTF-8 so the result is a plain ``str`` that can go into a JSON body.
    """
    with open(image_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read())
    return encoded.decode("utf-8")


19
20
def process_image_path(image_path) -> Any | str:
    """Process image_path: convert to base64 if local path, keep unchanged if HTTP link"""
gaclove's avatar
gaclove committed
21
22
23
24
25
26
27
28
29
30
31
32
33
    if not image_path:
        return image_path

    if image_path.startswith(("http://", "https://")):
        return image_path

    if os.path.exists(image_path):
        return image_to_base64(image_path)
    else:
        logger.warning(f"Image path not found: {image_path}")
        return image_path


34
35
def _advance_progress(complete_bar, complete_lock):
    """Advance ``complete_bar`` by one under ``complete_lock`` when both are provided."""
    # Compare against None rather than relying on truthiness: tqdm's __bool__
    # depends on its total and can even raise.
    if complete_bar is not None and complete_lock is not None:
        with complete_lock:
            complete_bar.update(1)


def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock, *, poll_interval=0.5, timeout=None):
    """Send task to server and monitor until completion.

    Args:
        url: Base URL of the server; ``/v1/tasks/`` is appended for submission.
        message: Task payload sent as the JSON body. A truthy ``image_path``
            entry is passed through ``process_image_path``. The change is made
            on a shallow copy; the caller's dict is not modified.
        task_index: Zero-based task position, used only in log messages.
        complete_bar: Optional progress bar (``update(1)`` is called once per
            task outcome: completed, failed, or submission error); may be None.
        complete_lock: Lock guarding ``complete_bar`` updates; may be None.
        poll_interval: Seconds to sleep between status checks (default 0.5).
        timeout: Per-request timeout passed to ``requests``. ``None`` (the
            default) means requests wait indefinitely.

    Returns:
        True if the server reports ``"completed"``; False if submission raises,
        no ``task_id`` is returned, or the server reports ``"failed"``.

    Note:
        Polling continues indefinitely while the status is anything other than
        ``"completed"`` / ``"failed"``, including while status requests keep
        raising.
    """
    # Step 1: Submit the task and obtain its task_id
    try:
        # Shallow copy so the caller's message list is not rewritten with
        # (potentially very large) base64 image payloads.
        payload = dict(message)
        if payload.get("image_path"):
            payload["image_path"] = process_image_path(payload["image_path"])

        response = requests.post(f"{url}/v1/tasks/", json=payload, timeout=timeout)
        response_data = response.json()
        task_id = response_data.get("task_id")
    except Exception as e:
        logger.error(f"Failed to send task to {url}: {e}")
        # Count the task as finished so the progress bar can still reach its total.
        _advance_progress(complete_bar, complete_lock)
        return False

    if not task_id:
        logger.error(f"No task_id received from {url}")
        _advance_progress(complete_bar, complete_lock)
        return False

    # Step 2: Monitor task status until completion
    while True:
        try:
            status_response = requests.get(f"{url}/v1/tasks/{task_id}/status", timeout=timeout)
            task_status = status_response.json().get("status")
        except Exception as e:
            # Transient errors: log and retry on the next poll.
            logger.error(f"Failed to check status for task_id {task_id}: {e}")
            time.sleep(poll_interval)
            continue

        if task_status == "completed":
            _advance_progress(complete_bar, complete_lock)
            return True
        if task_status == "failed":
            logger.error(f"Task {task_index + 1} (task_id: {task_id}) failed")
            # Still update progress even if failed
            _advance_progress(complete_bar, complete_lock)
            return False
        time.sleep(poll_interval)


def get_available_urls(urls, *, timeout=None):
    """Check which URLs are available and return the list.

    A URL counts as available when ``GET {url}/v1/service/status`` succeeds and
    its body parses as JSON. Any exception (connection error, timeout, non-JSON
    body) marks the URL unavailable, and the reason is logged as a warning.

    Args:
        urls: Iterable of server base URLs.
        timeout: Per-request timeout passed to ``requests``. ``None`` (the
            default) means requests wait indefinitely.

    Returns:
        The reachable URLs in input order, or ``None`` if none responded.
    """
    available_urls = []
    for url in urls:
        try:
            requests.get(f"{url}/v1/service/status", timeout=timeout).json()
        except Exception as e:
            # Offline servers are expected; record why rather than dropping it silently.
            logger.warning(f"Server {url} is unavailable: {e}")
            continue
        available_urls.append(url)

    if not available_urls:
        logger.error("No available urls.")
        return None

    logger.info(f"available_urls: {available_urls}")
    return available_urls

Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
96

97
98
def find_idle_server(available_urls, *, poll_interval=3, timeout=None):
    """Find an idle server from available URLs, blocking until one is idle.

    Servers are queried in order via ``GET {url}/v1/service/status``. The first
    one whose JSON body has ``service_status == "idle"`` is returned. Request or
    parsing errors are treated as "not idle". If no server is idle, the whole
    scan repeats after ``poll_interval`` seconds, with no upper bound on the wait.

    Args:
        available_urls: Iterable of server base URLs. It is materialized once so
            a generator is not exhausted after the first scan.
        poll_interval: Seconds to wait between full scans (default 3).
        timeout: Per-request timeout passed to ``requests``. ``None`` (the
            default) means requests wait indefinitely.

    Returns:
        The first URL found reporting ``"idle"``.

    Raises:
        ValueError: If ``available_urls`` is empty, since no server could ever
            become idle and the scan would otherwise wait forever.
    """
    urls = list(available_urls)
    if not urls:
        raise ValueError("find_idle_server() requires at least one URL")

    while True:
        for url in urls:
            try:
                status = requests.get(f"{url}/v1/service/status", timeout=timeout).json()
                if status["service_status"] == "idle":
                    return url
            except Exception as e:
                # Best-effort probe: an unreachable/malformed server is just "not idle now".
                logger.debug(f"Status probe failed for {url}: {e}")
                continue
        time.sleep(poll_interval)


def process_tasks_async(messages, available_urls, show_progress=True):
    """Process a list of tasks asynchronously across multiple servers.

    For each message, waits for ``find_idle_server`` to return an idle server.
    It then starts a non-daemon thread running ``send_and_monitor_task`` against
    that server. After all tasks are dispatched, it joins every thread.

    Args:
        messages: Sequence of task payloads (``len()`` sizes the progress bar).
        available_urls: Server base URLs to dispatch to.
        show_progress: If True, show a tqdm bar counting finished tasks.

    Returns:
        False if ``available_urls`` is empty; otherwise True once every worker
        thread has finished. Individual task failures do not change the return
        value; they are counted and logged as a warning.
    """
    if not available_urls:
        logger.error("No available servers to process tasks.")
        return False

    logger.info(f"Sending {len(messages)} tasks to available servers...")

    complete_bar = None
    complete_lock = None
    if show_progress:
        complete_bar = tqdm(total=len(messages), desc="Completing tasks")
        complete_lock = threading.Lock()  # Thread-safe updates to completion bar

    # results[i] receives send_and_monitor_task's return value for task i;
    # each worker writes only its own slot.
    results = [None] * len(messages)

    def _run_task(idx, server_url, message):
        """Thread target: run one task and record its outcome."""
        results[idx] = send_and_monitor_task(server_url, message, idx, complete_bar, complete_lock)

    active_threads = []
    try:
        for idx, message in enumerate(messages):
            # Find an idle server
            server_url = find_idle_server(available_urls)

            # Create and start thread for sending and monitoring task
            thread = threading.Thread(target=_run_task, args=(idx, server_url, message))
            thread.daemon = False
            thread.start()
            active_threads.append(thread)

            # Small delay to let the thread start (and, presumably, let the chosen
            # server leave the "idle" state before the next probe).
            time.sleep(0.5)

        # Wait for all threads to complete
        for thread in active_threads:
            thread.join()
    finally:
        # Always close the bar, even on error/interrupt. Compare against None:
        # tqdm(total=0) is falsy, so a truthiness check would skip closing it.
        if complete_bar is not None:
            complete_bar.close()

    failed_count = sum(1 for outcome in results if not outcome)
    if failed_count:
        logger.warning(f"{failed_count}/{len(messages)} tasks failed")

    logger.info("All tasks processing completed!")
    return True