chooser.py 5.82 KB
Newer Older
wuxk1's avatar
wuxk1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from threading import Event

import torch

from server import PromptServer
from aiohttp import web
from comfy import model_management as mm
from comfy_execution.graph import ExecutionBlocker
import time

class ChooserCancelled(Exception):
    """Raised when the user cancels a pending manual image selection."""

def get_chooser_cache():
    """Return the dict holding chooser state, keyed by node id.

    The dict is stored as the ``_easyuse_chooser_node`` attribute of
    ``PromptServer.instance`` and is created lazily on first access, so
    repeated calls hand back the same dict object.
    """
    server = PromptServer.instance
    if not hasattr(server, '_easyuse_chooser_node'):
        setattr(server, '_easyuse_chooser_node', {})
    return server._easyuse_chooser_node

def cleanup_session_data(node_id):
    """Remove the per-session fields stored for ``node_id``.

    Only the keys describing an in-flight selection session are dropped
    ("event", "selected", "images", "total_count", "cancelled"). Any other
    entry, such as "last_selection", is kept. Unknown node ids are ignored.
    """
    cache = get_chooser_cache()
    if node_id not in cache:
        return
    entry = cache[node_id]
    for field in ("event", "selected", "images", "total_count", "cancelled"):
        entry.pop(field, None)

def _reuse_last_selection(node_id, frames):
    """Return the stored selection for ``node_id`` concatenated, or None.

    Indices saved by an earlier run are filtered to those still inside
    ``frames``. If any remain, the frontend is notified (best-effort), the
    session fields are cleaned up, and the matching frames are concatenated
    along dim 0. Returns None when there is nothing usable to reuse.
    """
    node_data = get_chooser_cache()
    last_selection = node_data.get(node_id, {}).get("last_selection")
    if not last_selection:
        return None
    valid_indices = [idx for idx in last_selection if 0 <= idx < len(frames)]
    if not valid_indices:
        return None
    try:
        PromptServer.instance.send_sync("easyuse-image-keep-selection", {
            "id": node_id,
            "selected": valid_indices
        })
    except Exception:
        # Best-effort UI notification; the cached selection is used either way.
        pass
    cleanup_session_data(node_id)
    return torch.cat([frames[idx] for idx in valid_indices], dim=0)


def _wait_for_selection(node_id, event, period):
    """Poll the session for ``node_id`` until a selection or cancel arrives.

    Returns the recorded "selected" value, or None if the session entry was
    removed while waiting. Raises ChooserCancelled when the session is
    marked cancelled. ``mm.throw_exception_if_processing_interrupted()`` is
    called on each iteration, so an interrupt propagates out of here.
    """
    node_data = get_chooser_cache()
    while True:
        node_info = node_data.get(node_id)
        if node_info is None:
            return None
        if node_info.get("cancelled", False):
            raise ChooserCancelled("Manual selection cancelled")
        selected = node_info.get("selected")
        if selected is not None:
            return selected
        # Without this check, a UI interrupt would leave the node blocked here.
        mm.throw_exception_if_processing_interrupted()
        # handle_image_selection sets the event after updating the session,
        # so wake on it instead of always sleeping the full period. Clearing
        # it keeps a stray set() from turning this into a busy loop.
        if event.wait(period):
            event.clear()


def wait_for_chooser(id, images, mode, period=0.1):
    """Block until the user picks items for node ``id`` and return them.

    ``images`` is split along its first dimension into single-item slices
    (``images[i:i + 1, ...]``), so indices sent by the frontend map to
    entries of that dimension.

    In "Keep Last Selection" mode, a previously stored selection that is
    still in range is reused without waiting. Otherwise a session entry is
    registered in the chooser cache and checked at least every ``period``
    seconds until ``handle_image_selection`` records a selection or a
    cancellation.

    Args:
        id: Key for this node in the chooser cache. The name shadows the
            builtin but is kept for interface compatibility.
        images: Batched tensor whose first dimension indexes the candidates.
        mode: Selection mode; only "Keep Last Selection" is special-cased.
        period: Maximum seconds between checks of the session state.

    Returns:
        ``{"result": (value,)}`` where ``value`` is one of:
        - the selected slices concatenated along dim 0;
        - the first slice, when no valid selection was made or an
          unexpected error occurred;
        - ``ExecutionBlocker(None)``, when there is nothing to return.

    Raises:
        mm.InterruptProcessingException: when the user cancels the chooser,
            or when raised by ``mm.throw_exception_if_processing_interrupted()``
            while waiting.
    """
    # Defined before ``try`` so the error fallback never sees the raw tensor.
    frames = []
    try:
        node_data = get_chooser_cache()
        frames = [images[i:i + 1, ...] for i in range(images.shape[0])]
        if not frames:
            # An empty batch can never receive a valid selection from the UI.
            return {"result": (ExecutionBlocker(None),)}

        if mode == "Keep Last Selection":
            kept = _reuse_last_selection(id, frames)
            if kept is not None:
                return {"result": (kept,)}

        # Start a fresh session; this also discards any stored last_selection.
        node_data.pop(id, None)
        event = Event()
        node_data[id] = {
            "event": event,
            "images": frames,
            "selected": None,
            "total_count": len(frames),
            "cancelled": False,
        }

        selected_indices = _wait_for_selection(id, event, period)
        valid_indices = [idx for idx in (selected_indices or []) if 0 <= idx < len(frames)]
        if valid_indices:
            # Remember the choice for a later "Keep Last Selection" run.
            node_data.setdefault(id, {})["last_selection"] = valid_indices
            cleanup_session_data(id)
            return {"result": (torch.cat([frames[idx] for idx in valid_indices], dim=0),)}

        cleanup_session_data(id)
        return {"result": (frames[0],)}

    except ChooserCancelled:
        cleanup_session_data(id)
        raise mm.InterruptProcessingException()
    except mm.InterruptProcessingException:
        # Must not be swallowed by the generic fallback below.
        cleanup_session_data(id)
        raise
    except Exception:
        # Degrade to the first slice instead of failing the whole graph.
        cleanup_session_data(id)
        return {"result": (frames[0] if frames else ExecutionBlocker(None),)}


@PromptServer.instance.routes.post('/easyuse/image_chooser_message')
async def handle_image_selection(request):
    """Record a chooser action sent by the frontend for a waiting node.

    Expects a JSON object body with:
    - ``node_id``;
    - ``action``, either "select" or "cancel";
    - for "select", a ``selected`` list of integer indices.

    On success, the node's session entry in the chooser cache is updated and
    its event is set so the waiting side can observe the change.

    Responds with ``{"code": 1}`` on success. Otherwise it responds with
    ``{"code": -1}`` plus an ``error`` or ``message`` field describing the
    failure.
    """
    try:
        data = await request.json()
        node_id = data.get("node_id")
        selected = data.get("selected", [])
        action = data.get("action")

        node_data = get_chooser_cache()

        if node_id not in node_data:
            return web.json_response({"code": -1, "error": "Node data does not exist"})

        try:
            node_info = node_data[node_id]

            # Session fields (including total_count) are removed once a run
            # finishes, so a missing total_count means there is no live session.
            if "total_count" not in node_info:
                return web.json_response({"code": -1, "error": "The node has been processed"})

            if action == "cancel":
                node_info["cancelled"] = True
                node_info["selected"] = []
            elif action == "select" and isinstance(selected, list):
                total = node_info["total_count"]
                # bool is a subclass of int; reject JSON true/false explicitly
                # so they are not treated as indices 1/0.
                valid_indices = [
                    idx for idx in selected
                    if isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < total
                ]
                if valid_indices:
                    node_info["selected"] = valid_indices
                    node_info["cancelled"] = False
                else:
                    return web.json_response({"code": -1, "error": "Invalid Selection Index"})
            else:
                return web.json_response({"code": -1, "error": "Invalid operation"})

            node_info["event"].set()
            return web.json_response({"code": 1})

        except Exception:
            # Wake any waiter so it re-checks the session state.
            if node_id in node_data and "event" in node_data[node_id]:
                node_data[node_id]["event"].set()
            return web.json_response({"code": -1, "message": "Processing Failed"})

    except Exception:
        return web.json_response({"code": -1, "message": "Request Failed"})