pulid.py 12.9 KB
Newer Older
wuxk1's avatar
wuxk1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
"""
This module provides nodes and utilities for integrating the Nunchaku PuLID pipeline
with ComfyUI, enabling identity (ID) customization of Nunchaku FLUX models from reference
face images using PuLID and its supporting models (EVA-CLIP, InsightFace, facexlib).

.. note::

    Adapted from: https://github.com/lldacing/ComfyUI_PuLID_Flux_ll
"""

import copy
import logging
import os
from functools import partial
from types import MethodType

import comfy
import folder_paths
import numpy as np
import torch

from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDPipeline

from ...wrappers.flux import ComfyFluxWrapper
from .utils import set_extra_config_model_path

# Logging verbosity is read from the LOG_LEVEL environment variable (default "INFO").
# A value that is not an attribute of the logging module falls back to logging.INFO.
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=getattr(logging, log_level, logging.INFO),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# Configure the extra model folders used by the PuLID nodes; each folder key is
# passed to set_extra_config_model_path under its own name.
for _model_folder in ("pulid", "insightface", "facexlib"):
    set_extra_config_model_path(_model_folder, _model_folder)
del _model_folder


class NunchakuFluxPuLIDApplyV2:
    """
    Node for applying PuLID to a Nunchaku FLUX model.

    The node returns a copy of the input model whose ``ComfyFluxWrapper`` shares the
    original transformer but carries its own ``pulid_pipeline`` and ``customized_forward``.
    The input model is restored to its original state before the node returns.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """
        Defines the input types and tooltips for the node.

        Returns
        -------
        dict
            A dictionary specifying the required inputs and their descriptions for the node interface.
        """
        return {
            "required": {
                "model": ("MODEL",),
                "pulid_pipline": ("PULID_PIPELINE",),
                "image": ("IMAGE",),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05}),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
            "optional": {
                "attn_mask": ("MASK",),
                "options": ("OPTIONS",),
            },
            "hidden": {"unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"
    CATEGORY = "Nunchaku"
    TITLE = "Nunchaku FLUX PuLID Apply V2"

    def apply(
        self,
        model,
        pulid_pipline: "PuLIDPipeline",
        image,
        weight: float,
        start_at: float,
        end_at: float,
        attn_mask=None,
        options=None,
        unique_id=None,
    ):
        """
        Apply PuLID ID customization according to the given image to the model.

        One ID embedding is extracted per image in the batch. Images where no face is
        found are skipped, and the remaining embeddings are averaged.

        Parameters
        ----------
        model : object
            The Nunchaku FLUX model to modify. Its ``model.diffusion_model`` must be a
            ``ComfyFluxWrapper``.
        pulid_pipline : :class:`~nunchaku.pipeline.pipeline_flux_pulid.PuLIDPipeline`
            The PuLID pipeline instance.
        image : torch.Tensor
            Batch of input images for identity embedding extraction, indexed along the
            first dimension. Values are scaled by 255 and clipped to uint8; this presumes
            a 0..1 float range.
        weight : float
            The strength of the identity guidance (passed to ``pulid_forward`` as ``id_weight``).
        start_at : float
            Start of the active range, passed to ``pulid_forward`` as ``start_timestep``.
        end_at : float
            End of the active range, passed to ``pulid_forward`` as ``end_timestep``.
        attn_mask : optional
            Not supported for now.
        options : optional
            Additional options (unused).
        unique_id : optional
            Unique identifier (unused).

        Returns
        -------
        tuple
            A one-element tuple with the modified model copy. It is the unchanged input
            model when no face is detected in any image.

        Raises
        ------
        NotImplementedError
            If attn_mask is provided.
        """
        # Fail fast: reject unsupported inputs before running face detection or copying the model.
        if attn_mask is not None:
            raise NotImplementedError("Attn mask is not supported for now in Nunchaku FLUX PuLID Apply V2.")

        all_embeddings = []
        for frame in image:
            pixels = frame.squeeze().cpu().numpy() * 255.0
            pixels = np.clip(pixels, 0, 255).astype(np.uint8)

            id_embedding, _ = pulid_pipline.get_id_embedding(pixels)
            if id_embedding is not None:
                all_embeddings.append(id_embedding)

        if not all_embeddings:
            logger.warning("Nunchaku PuLID: No face detected in any of the images. Skipping PuLID.")
            return (model,)

        id_embeddings = torch.mean(torch.stack(all_embeddings), dim=0)

        ret_model = self._copy_model_sharing_transformer(model)
        ret_model_wrapper = ret_model.model.diffusion_model
        ret_model_wrapper.pulid_pipeline = pulid_pipline
        ret_model_wrapper.customized_forward = partial(
            pulid_forward, id_embeddings=id_embeddings, id_weight=weight, start_timestep=start_at, end_timestep=end_at
        )

        return (ret_model,)

    @staticmethod
    def _copy_model_sharing_transformer(model):
        """
        Deep-copy ``model`` without copying its transformer or any attached PuLID pipeline.

        The transformer (and, if present, ``pulid_pipeline``) is detached from the wrapper
        while ``copy.deepcopy`` runs. It is re-attached to both the original and the copy
        afterwards. The original is always restored, even if the copy fails.
        """
        model_wrapper = model.model.diffusion_model
        assert isinstance(model_wrapper, ComfyFluxWrapper)
        transformer = model_wrapper.model
        # A pipeline attached by a previous PuLID node is built around the transformer,
        # so it must not be deep-copied either.
        has_pipeline = hasattr(model_wrapper, "pulid_pipeline")
        previous_pipeline = getattr(model_wrapper, "pulid_pipeline", None)

        model_wrapper.model = None
        if has_pipeline:
            model_wrapper.pulid_pipeline = None
        try:
            ret_model = copy.deepcopy(model)  # copy everything except the heavy modules
        finally:
            model_wrapper.model = transformer
            if has_pipeline:
                model_wrapper.pulid_pipeline = previous_pipeline

        ret_model_wrapper = ret_model.model.diffusion_model
        assert isinstance(ret_model_wrapper, ComfyFluxWrapper)
        ret_model_wrapper.model = transformer
        if has_pipeline:
            ret_model_wrapper.pulid_pipeline = previous_pipeline
        return ret_model


class NunchakuPuLIDLoaderV2:
    """
    Node that builds a PuLID pipeline on top of a Nunchaku FLUX model.

    The pipeline is created from a PuLID checkpoint, an EVA-CLIP checkpoint and the
    first configured InsightFace / facexlib folders. The input model is passed
    through unchanged alongside the new pipeline.
    """

    RETURN_TYPES = ("MODEL", "PULID_PIPELINE")
    FUNCTION = "load"
    CATEGORY = "Nunchaku"
    TITLE = "Nunchaku PuLID Loader V2"

    @classmethod
    def INPUT_TYPES(cls):
        """
        Describe the node's inputs for the ComfyUI interface.

        Returns
        -------
        dict
            ``{"required": {...}}``. The file choices are the model lists currently found
            in the ``pulid`` and ``clip`` folders.
        """
        required = {
            "model": ("MODEL", {"tooltip": "The nunchaku model."}),
            "pulid_file": (folder_paths.get_filename_list("pulid"), {"tooltip": "Path to the PuLID model."}),
            "eva_clip_file": (folder_paths.get_filename_list("clip"), {"tooltip": "Path to the EVA clip model."}),
            "insight_face_provider": (["gpu", "cpu"], {"default": "gpu", "tooltip": "InsightFace ONNX provider."}),
        }
        return {"required": required}

    def load(self, model, pulid_file: str, eva_clip_file: str, insight_face_provider: str):
        """
        Construct a ``PuLIDPipeline`` bound to the transformer of ``model``.

        Parameters
        ----------
        model : object
            Nunchaku FLUX model whose ``model.diffusion_model`` is a ``ComfyFluxWrapper``.
        pulid_file : str
            PuLID checkpoint name, resolved inside the ``pulid`` folder.
        eva_clip_file : str
            EVA-CLIP checkpoint name, resolved inside the ``clip`` folder.
        insight_face_provider : str
            ONNX provider for InsightFace ("gpu" or "cpu").

        Returns
        -------
        tuple
            ``(model, pulid_pipeline)``. The pipeline uses the current torch device and
            the dtype of the transformer's first parameter.
        """
        wrapper = model.model.diffusion_model
        assert isinstance(wrapper, ComfyFluxWrapper)
        dit = wrapper.model

        # Keyword arguments are evaluated left to right, so device/dtype are looked up
        # before the checkpoint paths are resolved (which may raise).
        pipeline = PuLIDPipeline(
            dit=dit,
            device=comfy.model_management.get_torch_device(),
            weight_dtype=next(dit.parameters()).dtype,
            onnx_provider=insight_face_provider,
            pulid_path=folder_paths.get_full_path_or_raise("pulid", pulid_file),
            eva_clip_path=folder_paths.get_full_path_or_raise("clip", eva_clip_file),
            insightface_dirpath=folder_paths.get_folder_paths("insightface")[0],
            facexlib_dirpath=folder_paths.get_folder_paths("facexlib")[0],
        )
        return model, pipeline


class NunchakuPulidApply:
    """
    Deprecated node for applying PuLID to a Nunchaku FLUX model.

    Unlike the V2 node, this one patches the input model in place: it replaces the
    transformer's ``forward`` with a bound ``pulid_forward`` and returns the same model.

    Attributes
    ----------
    pulid_device : str
        Device name stored for PuLID inference (default: "cuda"); not read by ``apply``.
    weight_dtype : torch.dtype
        Weight data type (default: torch.bfloat16); not read by ``apply``.
    onnx_provider : str
        ONNX provider for InsightFace (default: "gpu"); not read by ``apply``.
    pretrained_model : object or None
        Always ``None`` here; not read by ``apply``.

    .. warning::
        This node is deprecated and will be removed in December 2025.
        Please use :class:`NunchakuFluxPuLIDApplyV2` instead.
    """

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"
    CATEGORY = "Nunchaku"
    TITLE = "Nunchaku Pulid Apply (Deprecated)"

    def __init__(self):
        self.pretrained_model = None
        self.onnx_provider = "gpu"
        self.weight_dtype = torch.bfloat16
        self.pulid_device = "cuda"

    @classmethod
    def INPUT_TYPES(cls):
        """
        Describe the node's inputs for the ComfyUI interface.

        Returns
        -------
        dict
            ``{"required": {...}}`` with the PuLID handle, the image, the model and ``ip_weight``.
        """
        ip_weight_spec = {
            "default": 1.0,
            "min": 0.0,
            "max": 2.0,
            "step": 0.01,
            "tooltip": "ip_weight",
        }
        required = {
            "pulid": ("PULID", {"tooltip": "from Nunchaku Pulid Loader"}),
            "image": ("IMAGE", {"tooltip": "The image to encode"}),
            "model": ("MODEL", {"tooltip": "The nunchaku model."}),
            "ip_weight": ("FLOAT", ip_weight_spec),
        }
        return {"required": required}

    def apply(self, pulid, image, model, ip_weight):
        """
        Patch the model's transformer so its forward pass uses a PuLID ID embedding.

        Parameters
        ----------
        pulid : object
            The PuLID pipeline instance (must provide ``get_id_embedding``).
        image : torch.Tensor
            Image to encode. It is squeezed, scaled by 255 and clipped to uint8.
            NOTE(review): this presumes a single image in 0..1 range; a batch > 1 is
            passed to ``get_id_embedding`` as-is — confirm.
        model : object
            The Nunchaku FLUX model. It is modified in place.
        ip_weight : float
            The weight for the identity embedding (``id_weight`` of ``pulid_forward``).

        Returns
        -------
        tuple
            ``(model,)`` — the same object that was passed in.
        """
        logger.warning(
            'This node is deprecated and will be removed in December 2025. Directly use "Nunchaku FLUX PuLID Apply V2" instead.'
        )

        pixels = np.clip(image.squeeze().cpu().numpy() * 255.0, 0, 255).astype(np.uint8)
        id_embeddings, _ = pulid.get_id_embedding(pixels)

        transformer = model.model.diffusion_model.model
        bound_forward = partial(pulid_forward, id_embeddings=id_embeddings, id_weight=ip_weight)
        transformer.forward = MethodType(bound_forward, transformer)
        return (model,)


class NunchakuPulidLoader:
    """
    Deprecated node for loading the PuLID pipeline for a Nunchaku FLUX model.

    .. warning::
        This node is deprecated and will be removed in December 2025.
        Use :class:`NunchakuPuLIDLoaderV2` instead.

    Attributes
    ----------
    pulid_device : str
        Device to load the PuLID pipeline on (default: "cuda").
    weight_dtype : torch.dtype
        Data type for model weights (default: torch.bfloat16).
    onnx_provider : str
        ONNX provider to use (default: "gpu").
    pretrained_model : str or None
        Argument passed to ``PuLIDPipeline.load_pretrain``. It is ``None`` by default;
        presumably ``None`` selects the pipeline's default checkpoint (confirm).
    """

    def __init__(self):
        """
        Initialize the loader with default device, dtype, and ONNX provider.
        """
        self.pulid_device = "cuda"
        self.weight_dtype = torch.bfloat16
        self.onnx_provider = "gpu"
        self.pretrained_model = None

    @classmethod
    def INPUT_TYPES(cls):
        """
        Returns the required input types for this node.

        Returns
        -------
        dict
            Dictionary specifying required inputs.
        """
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The nunchaku model."}),
            }
        }

    RETURN_TYPES = ("MODEL", "PULID")
    FUNCTION = "load"
    CATEGORY = "Nunchaku"
    TITLE = "Nunchaku Pulid Loader (Deprecated)"

    def load(self, model):
        """
        Load the PuLID pipeline for the given Nunchaku FLUX model.

        .. warning::
            This node is deprecated and will be removed in December 2025.
            Use :class:`NunchakuPuLIDLoaderV2` instead.

        Parameters
        ----------
        model : object
            The Nunchaku FLUX model. Its ``model.diffusion_model.model`` (the transformer)
            is handed to the pipeline.

        Returns
        -------
        tuple
            The input model and the loaded PuLID pipeline.
        """
        # The message names the V2 loader by its exact node TITLE so users can find it.
        logger.warning(
            'This node is deprecated and will be removed in December 2025. Directly use "Nunchaku PuLID Loader V2" instead.'
        )
        pulid_model = PuLIDPipeline(
            dit=model.model.diffusion_model.model,
            device=self.pulid_device,
            weight_dtype=self.weight_dtype,
            onnx_provider=self.onnx_provider,
        )
        pulid_model.load_pretrain(self.pretrained_model)

        return (model, pulid_model)