api.py 4.45 KB
Newer Older
wuxk1's avatar
wuxk1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import comfy.utils
from ..libs.api.fluxai import fluxaiAPI
from ..libs.api.bizyair import bizyairAPI, encode_data
from nodes import NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS

class joyCaption2API:
    """ComfyUI node that captions an image through the BizyAir JoyCaption2 supernode.

    Subclasses can point at a different endpoint by overriding ``API_URL`` and
    change the pre-upload downscale threshold by overriding ``SIZE_LIMIT``.
    """

    # Endpoint path forwarded to ``bizyairAPI.joyCaption`` as ``API_URL``.
    API_URL = "/supernode/joycaption2"
    # Images whose height or width exceeds this many pixels are scaled down
    # before being sent to the service.
    SIZE_LIMIT = 1536

    @classmethod
    def INPUT_TYPES(s):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "image": ("IMAGE",),
                "do_sample": ([True, False],),
                "temperature": (
                    "FLOAT",
                    {
                        "default": 0.5,
                        "min": 0.0,
                        "max": 2.0,
                        "step": 0.01,
                        "round": 0.001,
                        "display": "number",
                    },
                ),
                "max_tokens": (
                    "INT",
                    {
                        "default": 256,
                        "min": 16,
                        "max": 512,
                        "step": 16,
                        "display": "number",
                    },
                ),
                "caption_type": (
                    [
                        "Descriptive",
                        "Descriptive (Informal)",
                        "Training Prompt",
                        "MidJourney",
                        "Booru tag list",
                        "Booru-like tag list",
                        "Art Critic",
                        "Product Listing",
                        "Social Media Post",
                    ],
                ),
                "caption_length": (
                    # Either a length keyword or a target word count (20..260, step 10).
                    ["any", "very short", "short", "medium-length", "long", "very long"]
                    + [str(i) for i in range(20, 261, 10)],
                ),
                "extra_options": (
                    "STRING",
                    {
                        "placeholder": "Extra options(e.g):\nIf there is a person/character in the image you must refer to them as {name}.",
                        "tooltip": "Extra options for the model",
                        "multiline": True,
                    },
                ),
                "name_input": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Name input is only used if an Extra Option is selected that requires it.",
                    },
                ),
                "custom_prompt": (
                    "STRING",
                    {
                        "default": "",
                        "multiline": True,
                    },
                ),
            },
            "optional": {
                "apikey_override": ("STRING", {"default": "", "forceInput": True, "tooltip": "Override the API key in the local config"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("caption",)

    FUNCTION = "joycaption"
    OUTPUT_NODE = False

    CATEGORY = "EasyUse/API"

    def joycaption(
            self,
            image,
            do_sample,
            temperature,
            max_tokens,
            caption_type,
            caption_length,
            extra_options,
            name_input,
            custom_prompt,
            apikey_override=None
    ):
        """Request a caption for ``image`` from the configured BizyAir endpoint.

        Args:
            image: ComfyUI IMAGE batch, laid out (batch, height, width, channels).
            do_sample: Value from the ``[True, False]`` combo; only a value equal
                to ``True`` enables sampling.
            temperature: Sampling temperature, forwarded unchanged.
            max_tokens: Forwarded to the service as ``max_new_tokens``.
            caption_type: One of the caption styles offered in ``INPUT_TYPES``.
            caption_length: A length keyword or a word count given as a string.
            extra_options: Free-text extra instructions, sent as a one-item list.
            name_input: Forwarded unchanged as ``name_input``.
            custom_prompt: Forwarded unchanged as ``custom_prompt``.
            apikey_override: Optional key handed to ``bizyairAPI.joyCaption``;
                per the input tooltip it overrides the locally configured key.

        Returns:
            A one-element tuple holding the value returned by
            ``bizyairAPI.joyCaption`` (the caption string).
        """
        pbar = comfy.utils.ProgressBar(100)
        pbar.update_absolute(10)

        # ComfyUI IMAGE tensors are (batch, height, width, channels).
        _, height, width, _ = image.shape
        if max(height, width) > self.SIZE_LIMIT:
            # Delegate resizing to the "easy imageScaleDownToSize" node from the
            # global ComfyUI registry instead of duplicating its logic here.
            # NOTE(review): the trailing True flag's meaning is defined by that
            # node's image_scale_down_to_size signature — confirm there.
            scaler = ALL_NODE_CLASS_MAPPINGS['easy imageScaleDownToSize']()
            image, = scaler.image_scale_down_to_size(image, self.SIZE_LIMIT, True)

        payload = {
            # Left as None: the image tensor is passed to bizyairAPI.joyCaption
            # separately (presumably it encodes and fills this field — confirm).
            "image": None,
            # Explicit comparison: anything other than True/1 (e.g. a string)
            # disables sampling rather than being truthy.
            "do_sample": do_sample == True,  # noqa: E712
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "caption_type": caption_type,
            "caption_length": caption_length,
            "extra_options": [extra_options],
            "name_input": name_input,
            "custom_prompt": custom_prompt,
        }

        pbar.update_absolute(30)
        caption = bizyairAPI.joyCaption(payload, image, apikey_override, API_URL=self.API_URL)

        pbar.update_absolute(100)
        return (caption,)

class joyCaption3API(joyCaption2API):
    """JoyCaption node variant that targets the BizyAir JoyCaption3 supernode.

    Inputs, outputs and request handling are inherited unchanged from
    ``joyCaption2API``; only the endpoint path differs.
    """

    # Endpoint path forwarded to ``bizyairAPI.joyCaption`` as ``API_URL``.
    API_URL = "/supernode/joycaption3"

# Single source of truth for the nodes this module registers with ComfyUI:
# (unique node type key, implementing class, title shown in the UI).
_JOY_CAPTION_NODES = (
    ("easy joyCaption2API", joyCaption2API, "JoyCaption2 (BizyAIR)"),
    ("easy joyCaption3API", joyCaption3API, "JoyCaption3 (BizyAIR)"),
)

# Node type key -> node class.
NODE_CLASS_MAPPINGS = {key: node_cls for key, node_cls, _ in _JOY_CAPTION_NODES}

# Node type key -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {key: title for key, _, title in _JOY_CAPTION_NODES}