node_utils.py 24.8 KB
Newer Older
YiYi Xu's avatar
YiYi Xu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam


logger = logging.getLogger(__name__)

# NOTE (YiYi): this schema is actually SDXL-specific; it lives here for now.
# User-facing SDXL call arguments keyed by parameter name. Every key is the `name` of its InputParam, so the mapping
# is derived from the params themselves rather than repeating each name by hand.
SDXL_INPUTS_SCHEMA = {
    schema_param.name: schema_param
    for schema_param in (
        InputParam(
            "prompt",
            type_hint=Union[str, List[str]],
            description="The prompt or prompts to guide the image generation",
        ),
        InputParam(
            "prompt_2",
            type_hint=Union[str, List[str]],
            description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
        ),
        InputParam(
            "negative_prompt",
            type_hint=Union[str, List[str]],
            description="The prompt or prompts not to guide the image generation",
        ),
        InputParam(
            "negative_prompt_2",
            type_hint=Union[str, List[str]],
            description="The negative prompt or prompts for text_encoder_2",
        ),
        InputParam(
            "cross_attention_kwargs",
            type_hint=Optional[dict],
            description="Kwargs dictionary passed to the AttentionProcessor",
        ),
        InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
        InputParam(
            "image",
            type_hint=PipelineImageInput,
            required=True,
            description="The image(s) to modify for img2img or inpainting",
        ),
        InputParam(
            "mask_image",
            type_hint=PipelineImageInput,
            required=True,
            description="Mask image for inpainting, white pixels will be repainted",
        ),
        InputParam(
            "generator",
            type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
            description="Generator(s) for deterministic generation",
        ),
        InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
        InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
        InputParam(
            "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
        ),
        InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
        InputParam(
            "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
        ),
        InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
        InputParam(
            "denoising_end",
            type_hint=Optional[float],
            description="Fraction of denoising process to complete before termination",
        ),
        # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
        InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
        InputParam(
            "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
        ),
        InputParam(
            "latents",
            type_hint=Optional[torch.Tensor],
            description="Pre-generated noisy latents for image generation",
        ),
        InputParam(
            "padding_mask_crop",
            type_hint=Optional[Tuple[int, int]],
            description="Size of margin in crop for image and mask",
        ),
        InputParam(
            "original_size",
            type_hint=Optional[Tuple[int, int]],
            description="Original size of the image for SDXL's micro-conditioning",
        ),
        InputParam(
            "target_size",
            type_hint=Optional[Tuple[int, int]],
            description="Target size for SDXL's micro-conditioning",
        ),
        InputParam(
            "negative_original_size",
            type_hint=Optional[Tuple[int, int]],
            description="Negative conditioning based on image resolution",
        ),
        InputParam(
            "negative_target_size",
            type_hint=Optional[Tuple[int, int]],
            description="Negative conditioning based on target resolution",
        ),
        InputParam(
            "crops_coords_top_left",
            type_hint=Tuple[int, int],
            default=(0, 0),
            description="Top-left coordinates for SDXL's micro-conditioning",
        ),
        InputParam(
            "negative_crops_coords_top_left",
            type_hint=Tuple[int, int],
            default=(0, 0),
            description="Negative conditioning crop coordinates",
        ),
        InputParam(
            "aesthetic_score",
            type_hint=float,
            default=6.0,
            description="Simulates aesthetic score of generated image",
        ),
        InputParam(
            "negative_aesthetic_score",
            type_hint=float,
            default=2.0,
            description="Simulates negative aesthetic score",
        ),
        InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
        InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
        InputParam(
            "ip_adapter_image",
            type_hint=PipelineImageInput,
            required=True,
            description="Image(s) to be used as IP adapter",
        ),
        InputParam(
            "control_image",
            type_hint=PipelineImageInput,
            required=True,
            description="ControlNet input condition",
        ),
        InputParam(
            "control_guidance_start",
            type_hint=Union[float, List[float]],
            default=0.0,
            description="When ControlNet starts applying",
        ),
        InputParam(
            "control_guidance_end",
            type_hint=Union[float, List[float]],
            default=1.0,
            description="When ControlNet stops applying",
        ),
        InputParam(
            "controlnet_conditioning_scale",
            type_hint=Union[float, List[float]],
            default=1.0,
            description="Scale factor for ControlNet outputs",
        ),
        InputParam(
            "guess_mode",
            type_hint=bool,
            default=False,
            description="Enables ControlNet encoder to recognize input without prompts",
        ),
        InputParam(
            "control_mode",
            type_hint=List[int],
            required=True,
            description="Control mode for union controlnet",
        ),
    )
}

# Values SDXL blocks pass to each other through the pipeline state (as opposed to user-facing call arguments),
# keyed by parameter name. ModularNode falls back to these type hints when a block input has no type hint of its own.
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
    "prompt_embeds": InputParam(
        "prompt_embeds",
        type_hint=torch.Tensor,
        required=True,
        description="Text embeddings used to guide image generation",
    ),
    "negative_prompt_embeds": InputParam(
        "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
    ),
    "pooled_prompt_embeds": InputParam(
        "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
    ),
    "negative_pooled_prompt_embeds": InputParam(
        "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
    ),
    "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
    "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
    "preprocess_kwargs": InputParam(
        "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
    ),
    "latents": InputParam(
        "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
    ),
    "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
    ),
    "latent_timestep": InputParam(
        "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
    ),
    "image_latents": InputParam(
        "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
    ),
    "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
    "masked_image_latents": InputParam(
        "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
    ),
    "add_time_ids": InputParam(
        "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
    ),
    "negative_add_time_ids": InputParam(
        "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
    ),
    "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
    "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
    # NOTE(review): `Tuple[int]` types a 1-tuple; if this holds crop-box coordinates it presumably should be
    # `Tuple[int, ...]` — confirm against the block that produces it.
    "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
    "ip_adapter_embeds": InputParam(
        "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
    ),
    "negative_ip_adapter_embeds": InputParam(
        "negative_ip_adapter_embeds",
        type_hint=List[torch.Tensor],
        description="Negative image embeddings for IP-Adapter",
    ),
    "images": InputParam(
        "images",
        # `np.ndarray` is the array type; `np.array` is only a factory function and is not a valid element type.
        type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
        required=True,
        description="Generated images",
    ),
}

# One lookup table over both SDXL schemas; when a name appears in both (e.g. "latents"), the intermediate entry wins.
SDXL_PARAM_SCHEMA = dict(SDXL_INPUTS_SCHEMA)
SDXL_PARAM_SCHEMA.update(SDXL_INTERMEDIATE_INPUTS_SCHEMA)


# Built-in UI specs for well-known parameter names; ModularNode copies the matching entry for inputs and outputs.
# Width and height are generated from one template and differ only in their label.
DEFAULT_PARAM_MAPS = {
    "prompt": dict(
        label="Prompt",
        type="string",
        default="a bear sitting in a chair drinking a milkshake",
        display="textarea",
    ),
    "negative_prompt": dict(
        label="Negative Prompt",
        type="string",
        default="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
        display="textarea",
    ),
    "num_inference_steps": dict(label="Steps", type="int", default=25, min=1, max=1000),
    "seed": dict(label="Seed", type="int", default=0, min=0, display="random"),
    **{
        axis: dict(
            label=axis.capitalize(),
            type="int",
            display="text",
            default=1024,
            min=8,
            max=8192,
            step=8,
            group="dimensions",
        )
        for axis in ("width", "height")
    },
    "images": dict(label="Images", type="image", display="output"),
    "image": dict(label="Image", type="image", display="input"),
}

# Fallback UI specs keyed by a substring of a parameter's lower-cased type hint. ModularNode scans the keys in this
# order and uses the first one found in the type string, so the order is significant.
DEFAULT_TYPE_MAPS = dict(
    int=dict(type="int", default=0, min=0),
    float=dict(type="float", default=0.0, min=0.0),
    str=dict(type="string", default=""),
    bool=dict(type="boolean", default=False),
    image=dict(type="image"),
)

# Component names (exact match) that get a "diffusers_auto_model" spec when they are not part of a group.
DEFAULT_MODEL_KEYS = "unet vae text_encoder tokenizer controlnet transformer image_encoder".split()
# Category the node is listed under unless one is passed explicitly.
DEFAULT_CATEGORY = "Modular Diffusers"
# Components whose name contains any of these substrings are skipped unless the user passes an explicit spec.
DEFAULT_EXCLUDE_MODEL_KEYS = "processor feature_extractor safety_checker".split()
# Group name -> substrings; a name containing any substring belongs to that group (see get_group_name).
# Groups are checked in this order and the first match wins.
DEFAULT_PARAMS_GROUPS_KEYS = dict(
    text_encoders=["text_encoder", "tokenizer"],
    ip_adapter_embeds=["ip_adapter_embeds"],
    prompt_embeddings=["prompt_embeds"],
)


def get_group_name(name, group_params_keys=None):
    """
    Get the group name for a given parameter name, or None if it is not part of any group.

    A name belongs to a group when any of that group's keys is a *substring* of the name. Groups are checked in
    insertion order and the first match wins. With the default groups, for example:
    "prompt_embeds" -> "prompt_embeddings", "negative_prompt_embeds" -> "prompt_embeddings",
    "text_encoder_2" -> "text_encoders", "prompt" -> None.

    Args:
        name (`str` or `None`): Parameter/component name to classify; `None` yields `None`.
        group_params_keys (`dict[str, list[str]]`, *optional*):
            Mapping of group name -> substrings identifying its members. Defaults to `DEFAULT_PARAMS_GROUPS_KEYS`.

    Returns:
        `str` or `None`: The matching group name, or `None` when no group matches.
    """
    if name is None:
        return None
    # Resolve the default at call time instead of binding a shared dict as the default argument.
    if group_params_keys is None:
        group_params_keys = DEFAULT_PARAMS_GROUPS_KEYS
    for group_name, group_keys in group_params_keys.items():
        if any(group_key in name for group_key in group_keys):
            return group_name
    return None


class ModularNode(ConfigMixin):
    """
    A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
    around a ModularPipelineBlocks object.

    On construction, the node walks the wrapped blocks' inputs and intermediate inputs, expected components and final
    outputs. It derives a parameter spec for each one it exposes and registers these specs to its config. The config
    also holds `name_mapping`, which maps diffusers names to Mellon names. `mellon_config` turns that config into a
    Mellon node definition. `execute` runs the pipeline with values coming from the UI.

    <Tip warning={true}>

        This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    config_name = "node_config.json"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ):
        """
        Load `ModularPipelineBlocks` and wrap them in a node.

        Args:
            pretrained_model_name_or_path (`str`): Forwarded to `ModularPipelineBlocks.from_pretrained`.
            trust_remote_code (`bool`, *optional*): Forwarded to `ModularPipelineBlocks.from_pretrained`.
            **kwargs: Forwarded to BOTH `ModularPipelineBlocks.from_pretrained` and the node constructor.

        Returns:
            `ModularNode`: A node wrapping the loaded blocks.
        """
        # NOTE(review): the same kwargs go to both calls, so loading options are also reported as "Unused kwargs"
        # by __init__, and per-parameter node overrides are passed to the blocks loader. Kept for compatibility.
        blocks = ModularPipelineBlocks.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )
        return cls(blocks, **kwargs)

    def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
        """
        Build and register the node's parameter specs.

        Args:
            blocks (`ModularPipelineBlocks`): The blocks to wrap.
            category (`str`, defaults to `DEFAULT_CATEGORY`): Category under which the node is listed.
            label (`str`, *optional*): Display label; defaults to the class name of `blocks`.
            **kwargs: Per-parameter overrides keyed by the diffusers name of an input, component or output, e.g.

                ModularNode(blocks, prompt={"name": "text_input", "label": "Prompt", "type": "string",
                                            "default": "a bear sitting in a chair drinking a milkshake",
                                            "display": "textarea"})

                The optional "name" key sets the Mellon name of the parameter (recorded in `name_mapping`). The
                remaining keys become the parameter spec. The caller's dict is not modified. Overrides that match
                nothing are logged as unused.

        Parameters without an override get a default spec, in this order of preference:
        - the matching entry of `DEFAULT_PARAM_MAPS`;
        - a group name from `get_group_name`, so several diffusers params share one Mellon param;
        - for required inputs, a spec derived from the type hint via `DEFAULT_TYPE_MAPS`;
        - otherwise the bare name. `mellon_config` turns a bare name into a custom param of that type, e.g.
          "scheduler" -> {"label": "scheduler", "type": "scheduler", "display": "input"}.

        NOTE(review): older docs suggested an override's "name" could itself be a dict (e.g.
        text_encoder={"name": {"text_encoders": "text_encoder"}}). Such a value ends up as a dict key in
        `_convert_to_mellon_config`, which raises TypeError. Confirm the intended format before relying on it.
        """
        self.blocks = blocks

        if label is None:
            label = self.blocks.__class__.__name__
        # blocks param name -> mellon param name; filled while the specs below are built
        self.name_mapping = {}

        # Each builder pops the overrides it consumes from `kwargs`, so anything left afterwards is unused.
        input_params = self._build_input_params(kwargs)
        component_params = self._build_component_params(kwargs)
        output_params = self._build_output_params(kwargs)

        if len(kwargs) > 0:
            logger.warning(f"Unused kwargs: {kwargs}")

        register_dict = {
            "category": category,
            "label": label,
            "input_params": input_params,
            "component_params": component_params,
            "output_params": output_params,
            "name_mapping": self.name_mapping,
        }
        self.register_to_config(**register_dict)

    def _pop_user_param(self, kwargs, name):
        """
        Pop the user override for `name` from `kwargs`.

        Returns a copy of the override without its "name" key, or `None` when no (truthy) override was given. A
        "name" that differs from `name` is recorded in `self.name_mapping`.
        """
        param = kwargs.pop(name, None)
        if not param:
            return None
        # Copy so the caller's dict keeps its "name" key and can be reused, e.g. for another node.
        param = param.copy()
        mellon_name = param.pop("name", name)
        if mellon_name != name:
            self.name_mapping[name] = mellon_name
        return param

    def _build_input_params(self, kwargs):
        """Build specs for the blocks' inputs and intermediate inputs, popping used overrides from `kwargs`."""
        input_params = {}
        user_defined = set()
        for inp in self.blocks.inputs + self.blocks.intermediate_inputs:
            if inp.name in user_defined:
                # A name can be listed in both `inputs` and `intermediate_inputs`; the later occurrence must not
                # overwrite a spec the user passed explicitly.
                continue
            param = self._pop_user_param(kwargs, inp.name)
            if param is not None:
                input_params[inp.name] = param
                user_defined.add(inp.name)
                continue

            group_name = get_group_name(inp.name)
            if inp.name in DEFAULT_PARAM_MAPS:
                param = DEFAULT_PARAM_MAPS[inp.name].copy()
            elif group_name:
                param = group_name
                self.name_mapping.setdefault(inp.name, group_name)
            elif inp.required:
                param = self._param_from_type_hint(inp)
            else:
                # optional, not a well-known parameter and not part of a group: not exposed on the node
                continue
            input_params[inp.name] = param
        return input_params

    @staticmethod
    def _param_from_type_hint(inp):
        """
        Derive a spec from `inp.type_hint`, falling back to the type hint in `SDXL_PARAM_SCHEMA`.

        The lower-cased type string is matched by substring against `DEFAULT_TYPE_MAPS` keys in order, so e.g.
        "typing.optional[int]" picks "int". When nothing matches, the bare input name is returned (a custom type).
        """
        if inp.type_hint is not None:
            type_str = str(inp.type_hint).lower()
        else:
            inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
            type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
        for type_key, type_param in DEFAULT_TYPE_MAPS.items():
            if type_key in type_str:
                param = type_param.copy()
                param["label"] = inp.name
                param["display"] = "input"
                return param
        return inp.name

    def _build_component_params(self, kwargs):
        """Build specs for the blocks' expected components, popping used overrides from `kwargs`."""
        component_params = {}
        for comp in self.blocks.expected_components:
            param = self._pop_user_param(kwargs, comp.name)
            if param is not None:
                component_params[comp.name] = param
                continue

            if any(exclude_key in comp.name for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS):
                continue

            group_name = get_group_name(comp.name)
            if group_name:
                param = group_name
                self.name_mapping.setdefault(comp.name, group_name)
            elif comp.name in DEFAULT_MODEL_KEYS:
                param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
            else:
                param = comp.name
            component_params[comp.name] = param
        return component_params

    def _build_output_params(self, kwargs):
        """Build specs for the node's outputs, popping used overrides from `kwargs`."""
        if isinstance(self.blocks, SequentialPipelineBlocks) and self.blocks.sub_blocks:
            # For sequential blocks only the last sub-block's outputs are exposed on the node.
            last_block_name = list(self.blocks.sub_blocks.keys())[-1]
            outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
        else:
            outputs = self.blocks.intermediate_outputs

        output_params = {}
        for out in outputs:
            param = self._pop_user_param(kwargs, out.name)
            if param is None:
                if out.name in DEFAULT_PARAM_MAPS:
                    param = DEFAULT_PARAM_MAPS[out.name].copy()
                    param["display"] = "output"
                else:
                    group_name = get_group_name(out.name)
                    if group_name:
                        param = group_name
                        self.name_mapping.setdefault(out.name, group_name)
                    else:
                        param = out.name
            output_params[out.name] = param
        return output_params

    def setup(self, components_manager, collection=None):
        """
        Create the underlying pipeline; must be called before `execute`.

        Args:
            components_manager: Passed to `self.blocks.init_pipeline`. `process_inputs` also uses it to look up
                components by their "model_id".
            collection (*optional*): Passed to `self.blocks.init_pipeline`.
        """
        self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
        self._components_manager = components_manager

    @property
    def mellon_config(self):
        """The Mellon node definition built from this node's config (recomputed on each access)."""
        return self._convert_to_mellon_config()

    def _convert_to_mellon_config(self):
        """Return {"label", "category", "params"}, with params keyed by Mellon name (inputs, components, outputs)."""
        node = {}
        node["label"] = self.config.label
        node["category"] = self.config.category

        node_param = {}
        self._add_mellon_params(node_param, self.config.input_params, display="input", kind="Input")
        self._add_mellon_params(node_param, self.config.component_params, display="input", kind="Component")
        self._add_mellon_params(node_param, self.config.output_params, display="output", kind="Output")
        node["params"] = node_param
        return node

    def _add_mellon_params(self, node_param, params, display, kind):
        """
        Add `params` (diffusers name -> spec) to `node_param` under their Mellon names.

        A string spec becomes a custom param whose label and type are that string. When several diffusers params
        map to the same Mellon name (e.g. a group), the first one added wins and the rest are logged and skipped.
        """
        for name, spec in params.items():
            mellon_name = self.name_mapping.get(name, name)
            if isinstance(spec, str):
                spec = {
                    "label": spec,
                    "type": spec,
                    "display": display,
                }
            if mellon_name not in node_param:
                node_param[mellon_name] = spec
            else:
                logger.debug(f"{kind} param {mellon_name} already exists in node_param, skipping {name}")

    def save_mellon_config(self, file_path):
        """
        Save the Mellon configuration to a JSON file.

        Args:
            file_path (str or Path): Path where the JSON file will be saved

        Returns:
            Path: Path to the saved config file
        """
        file_path = Path(file_path)

        # Create directory if it doesn't exist
        os.makedirs(file_path.parent, exist_ok=True)

        # Create a combined dictionary with module definition and name mapping
        config = {"module": self.mellon_config, "name_mapping": self.name_mapping}

        # Save the config to file
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        logger.info(f"Mellon config and name mapping saved to {file_path}")

        return file_path

    @classmethod
    def load_mellon_config(cls, file_path):
        """
        Load a Mellon configuration from a JSON file.

        Args:
            file_path (str or Path): Path to the JSON file containing Mellon config

        Returns:
            dict: The loaded combined configuration containing 'module' and 'name_mapping'

        Raises:
            FileNotFoundError: If `file_path` does not exist.
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Config file not found: {file_path}")

        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        logger.info(f"Mellon config loaded from {file_path}")

        return config

    def _pop_node_value(self, kwargs, name):
        """
        Take the value for diffusers param `name` out of Mellon-side `kwargs`, or return `None` if absent.

        The value is looked up under the Mellon name. If that value is a dict containing `name` (a grouped param),
        only that member is popped from the nested dict. Otherwise the whole top-level entry is popped.
        """
        mellon_name = self.name_mapping.get(name, name)
        if mellon_name not in kwargs:
            return None
        value = kwargs[mellon_name]
        if isinstance(value, dict) and name in value:
            # NOTE: this mutates the caller's nested dict on purpose. Once every member of a group has been taken
            # out, a later lookup of the same Mellon name pops the now-empty (falsy) dict.
            return value.pop(name)
        return kwargs.pop(mellon_name)

    def process_inputs(self, **kwargs):
        """
        Split Mellon-side keyword arguments into pipeline components, run arguments and requested outputs.

        Component values are expected to be dicts carrying a "model_id", which is resolved through the components
        manager passed to `setup`. Falsy component values and `None` inputs are skipped.

        Returns:
            tuple: `(params_components, params_run, return_output_names)`.
        """
        params_components = {}
        for comp_name in self.config.component_params:
            logger.debug(f"component: {comp_name}")
            comp = self._pop_node_value(kwargs, comp_name)
            if comp:
                params_components[comp_name] = self._components_manager.get_one(comp["model_id"])

        params_run = {}
        for inp_name in self.config.input_params:
            logger.debug(f"input: {inp_name}")
            inp = self._pop_node_value(kwargs, inp_name)
            if inp is not None:
                params_run[inp_name] = inp

        return_output_names = list(self.config.output_params.keys())

        return params_components, params_run, return_output_names

    def execute(self, **kwargs):
        """
        Run the pipeline on Mellon-side inputs after swapping in any components passed through them.

        Returns:
            Whatever `self.pipeline(...)` returns when asked for this node's output names.
        """
        params_components, params_run, return_output_names = self.process_inputs(**kwargs)

        self.pipeline.update_components(**params_components)
        output = self.pipeline(**params_run, output=return_output_names)
        return output