modular_blocks.py 15.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
18
19
20
from .before_denoise import (
    FluxImg2ImgPrepareLatentsStep,
    FluxImg2ImgSetTimestepsStep,
21
    FluxKontextRoPEInputsStep,
22
    FluxPrepareLatentsStep,
23
    FluxRoPEInputsStep,
24
25
    FluxSetTimestepsStep,
)
26
from .decoders import FluxDecodeStep
27
28
29
30
31
32
33
34
35
36
37
38
39
from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
from .encoders import (
    FluxKontextProcessImagesInputStep,
    FluxProcessImagesInputStep,
    FluxTextEncoderStep,
    FluxVaeEncoderDynamicStep,
)
from .inputs import (
    FluxInputsDynamicStep,
    FluxKontextInputsDynamicStep,
    FluxKontextSetResolutionStep,
    FluxTextInputStep,
)
40
41
42
43
44


# Module-level logger created through the package's `logging.get_logger` helper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# vae encoder (run before before_denoise)
# Ordered sub-steps of the Flux img2img VAE encoder: image preprocessing, then VAE encoding.
FluxImg2ImgVaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", FluxProcessImagesInputStep()),
        ("encode", FluxVaeEncoderDynamicStep()),
    ]
)


class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
    """Run the steps of `FluxImg2ImgVaeEncoderBlocks` in order (preprocess, then encode)."""

    model_name = "flux"

    block_classes = FluxImg2ImgVaeEncoderBlocks.values()
    block_names = FluxImg2ImgVaeEncoderBlocks.keys()

    @property
    def description(self) -> str:
        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."


class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
    """Auto block that runs `FluxImg2ImgVaeEncoderStep` when `image` is passed; skipped otherwise."""

    block_classes = [FluxImg2ImgVaeEncoderStep]
    block_names = ["img2img"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        # Each bullet ends with "\n" so the bullets render on separate lines.
        return (
            "Vae encoder step that encode the image inputs into their latent representations.\n"
            + "This is an auto pipeline block that works for img2img tasks.\n"
            + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided.\n"
            + " - if `image` is not provided, step will be skipped."
        )


# Flux Kontext vae encoder (run before before_denoise)
# Two sub-steps, in order: the Kontext image preprocessor, then a VAE encoder
# constructed with sample_mode="argmax".
FluxKontextVaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", FluxKontextProcessImagesInputStep()),
        ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
    ]
)


class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
    """Run the steps of `FluxKontextVaeEncoderBlocks` in order (Kontext preprocess, then encode)."""

    model_name = "flux-kontext"

    block_classes = FluxKontextVaeEncoderBlocks.values()
    block_names = FluxKontextVaeEncoderBlocks.keys()

    @property
    def description(self) -> str:
        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."


class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
    """Auto block that runs `FluxKontextVaeEncoderStep` when `image` is passed; skipped otherwise."""

    block_classes = [FluxKontextVaeEncoderStep]
    block_names = ["img2img"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        # Each bullet ends with "\n" so the bullets render on separate lines.
        return (
            "Vae encoder step that encode the image inputs into their latent representations.\n"
            + "This is an auto pipeline block that works for img2img tasks.\n"
            + " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided.\n"
            + " - if `image` is not provided, step will be skipped."
        )


# before_denoise: text2img
# Ordered sub-steps preparing latents, timesteps and RoPE inputs for text-to-image denoising.
FluxBeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxSetTimestepsStep()),
        ("prepare_rope_inputs", FluxRoPEInputsStep()),
    ]
)


class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
    """Run the steps of `FluxBeforeDenoiseBlocks` in order for text-to-image generation."""

    block_classes = FluxBeforeDenoiseBlocks.values()
    block_names = FluxBeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."


# before_denoise: img2img
# Like the text2img variant, but with img2img timesteps and an extra img2img latent-preparation
# step placed between timestep setup and RoPE-input preparation.
FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
        ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
        ("prepare_rope_inputs", FluxRoPEInputsStep()),
    ]
)


class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
    """Run the steps of `FluxImg2ImgBeforeDenoiseBlocks` in order for image-to-image generation."""

    block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
    block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return "Before denoise step that prepare the inputs for the denoise step for img2img task."


# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
    """Select the img2img before-denoise step when `image_latents` is given, else the text2image one."""

    # Plain Flux dispatcher: it wraps the non-Kontext before-denoise steps and is composed into
    # `FluxCoreDenoiseStep` (model_name "flux"), so it carries the same model name.
    model_name = "flux"

    block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
    block_names = ["img2img", "text2image"]
    # `None` marks the text2image step as the fallback when no trigger input is present.
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self):
        return (
            "Before denoise step that prepare the inputs for the denoise step.\n"
            + "This is an auto pipeline block that works for text2image and img2img tasks.\n"
            + " - `FluxBeforeDenoiseStep` (text2image) is used when `image_latents` is not provided.\n"
            + " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
        )


# before_denoise: FluxKontext
# Same latent/timestep preparation as text2img Flux; only the RoPE-input step is Kontext-specific.
FluxKontextBeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxSetTimestepsStep()),
        ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
    ]
)


class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
    """Run the steps of `FluxKontextBeforeDenoiseBlocks` in order."""

    block_names = FluxKontextBeforeDenoiseBlocks.keys()
    block_classes = FluxKontextBeforeDenoiseBlocks.values()

    @property
    def description(self):
        parts = (
            "Before denoise step that prepare the inputs for the denoise step",
            "for img2img/text2img task for Flux Kontext.",
        )
        return "\n".join(parts)


class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
    """Select the Kontext before-denoise step when `image_latents` is given, else plain text2image."""

    block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
    block_names = ["img2img", "text2image"]
    # `None` marks the text2image step as the fallback when no trigger input is present.
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self):
        return (
            "Before denoise step that prepare the inputs for the denoise step.\n"
            + "This is an auto pipeline block that works for text2image and img2img tasks.\n"
            + " - `FluxBeforeDenoiseStep` (text2image) is used when `image_latents` is not provided.\n"
            + " - `FluxKontextBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
        )


# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
    """Auto block wrapping `FluxDenoiseStep`; its only trigger is `None`, so it is the default choice."""

    block_classes = [FluxDenoiseStep]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        # Adjacent literals concatenate; each sentence ends with "\n" so they render on separate lines.
        return (
            "Denoise step that iteratively denoise the latents.\n"
            "This is an auto pipeline block that works for text2image and img2img tasks.\n"
            " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
        )


# denoise: Flux Kontext


class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
    """Auto block wrapping `FluxKontextDenoiseStep`; its only trigger is `None`, so it is the default choice."""

    block_classes = [FluxKontextDenoiseStep]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        # The listed sub-block must match `block_classes` (FluxKontextDenoiseStep, not FluxDenoiseStep).
        return (
            "Denoise step that iteratively denoise the latents for Flux Kontext.\n"
            "This is an auto pipeline block that works for text2image and img2img tasks.\n"
            " - `FluxKontextDenoiseStep` (denoise) for text2image and img2img tasks."
        )


# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
    """Auto block wrapping `FluxDecodeStep`; its only trigger is `None`, so it is the default choice."""

    block_classes = [FluxDecodeStep]
    block_names = ["non-inpaint"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"


# inputs: text2image/img2img
# Text-input standardization followed by the dynamic step handling additional (image) inputs.
FluxImg2ImgBlocks = InsertableDict(
    [
        ("text_inputs", FluxTextInputStep()),
        ("additional_inputs", FluxInputsDynamicStep()),
    ]
)


class FluxImg2ImgInputStep(SequentialPipelineBlocks):
    """Run the steps of `FluxImg2ImgBlocks` in order (text inputs, then additional inputs)."""

    model_name = "flux"
    block_classes = FluxImg2ImgBlocks.values()
    block_names = FluxImg2ImgBlocks.keys()

    @property
    def description(self):
        # The parentheses are required: they join the adjacent literals into one returned string.
        # Without them only the first line is returned and the rest are dead expression statements.
        return (
            "Input step that prepares the inputs for the img2img denoising step. It:\n"
            " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
            " - update height/width based on `image_latents`, patchify `image_latents`."
        )


class FluxAutoInputStep(AutoPipelineBlocks):
    """Select `FluxImg2ImgInputStep` when `image_latents` is given, else `FluxTextInputStep`."""

    block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
    block_names = ["img2img", "text2image"]
    # `None` marks the text-only input step as the fallback when no trigger input is present.
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self):
        return (
            "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
            " This is an auto pipeline block that works for text2image/img2img tasks.\n"
            + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
            + " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
        )


# inputs: Flux Kontext
# Resolution setup runs first, then text-input standardization, then the Kontext-specific
# dynamic step for additional (image) inputs.
FluxKontextBlocks = InsertableDict(
    [
        ("set_resolution", FluxKontextSetResolutionStep()),
        ("text_inputs", FluxTextInputStep()),
        ("additional_inputs", FluxKontextInputsDynamicStep()),
    ]
)


class FluxKontextInputStep(SequentialPipelineBlocks):
    """Run the steps of `FluxKontextBlocks` in order (resolution, text inputs, additional inputs)."""

    model_name = "flux-kontext"
    block_classes = FluxKontextBlocks.values()
    block_names = FluxKontextBlocks.keys()

    @property
    def description(self):
        return (
            "Input step that prepares the inputs for both the text2img and img2img denoising steps. It:\n"
            " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
            " - update height/width based on `image_latents`, patchify `image_latents`."
        )


class FluxKontextAutoInputStep(AutoPipelineBlocks):
    """Select `FluxKontextInputStep` when `image_latents` is given, else `FluxTextInputStep`."""

    block_classes = [FluxKontextInputStep, FluxTextInputStep]
    block_names = ["img2img", "text2img"]
    # `None` marks the text-only input step as the fallback when no trigger input is present.
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self):
        # The text2img bullet must name the block actually mapped to the `None` trigger above.
        return (
            "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
            " This is an auto pipeline block that works for text2image/img2img tasks.\n"
            + " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
            + " - `FluxTextInputStep` (text2img) is used when `image_latents` is not provided."
        )


class FluxCoreDenoiseStep(SequentialPipelineBlocks):
    """Chain the Flux auto input, before-denoise and denoise blocks into one denoising stage."""

    model_name = "flux"
    block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
    block_names = ["input", "before_denoise", "denoise"]

    @property
    def description(self):
        return (
            "Core step that performs the denoising process. \n"
            + " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
            + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
            + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
            + "This step supports text-to-image and image-to-image tasks for Flux:\n"
            + " - for image-to-image generation, you need to provide `image_latents`\n"
            + " - for text-to-image generation, all you need to provide is prompt embeddings."
        )


class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
    """Chain the Kontext auto input, before-denoise and denoise blocks into one denoising stage."""

    model_name = "flux-kontext"
    block_names = ["input", "before_denoise", "denoise"]
    block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]

    @property
    def description(self):
        text = (
            "Core step that performs the denoising process. \n"
            " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
            " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
            " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
            "This step supports text-to-image and image-to-image tasks for Flux:\n"
            " - for image-to-image generation, you need to provide `image_latents`\n"
            " - for text-to-image generation, all you need to provide is prompt embeddings."
        )
        return text


# Auto blocks (text2image and img2img)
# Top-level stages of the auto Flux pipeline, in order: text encoding, auto VAE encoding,
# the core denoise stage, and decoding.
AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
        ("image_encoder", FluxAutoVaeEncoderStep()),
        ("denoise", FluxCoreDenoiseStep()),
        ("decode", FluxDecodeStep()),
    ]
)

# Top-level stages of the auto Flux Kontext pipeline. Same layout as the plain auto pipeline,
# with the Kontext VAE encoder and Kontext core denoise stage swapped in.
AUTO_BLOCKS_KONTEXT = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
        ("image_encoder", FluxKontextAutoVaeEncoderStep()),
        ("denoise", FluxKontextCoreDenoiseStep()),
        ("decode", FluxDecodeStep()),
    ]
)

class FluxAutoBlocks(SequentialPipelineBlocks):
    """Auto Flux pipeline running the stages of `AUTO_BLOCKS` in order."""

    model_name = "flux"

    block_classes = AUTO_BLOCKS.values()
    block_names = AUTO_BLOCKS.keys()

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
            + "- for text-to-image generation, all you need to provide is `prompt`\n"
            + "- for image-to-image generation, you need to provide either `image` or `image_latents`"
        )


class FluxKontextAutoBlocks(FluxAutoBlocks):
    """Auto Flux Kontext pipeline running the stages of `AUTO_BLOCKS_KONTEXT` in order."""

    model_name = "flux-kontext"

    block_classes = AUTO_BLOCKS_KONTEXT.values()
    block_names = AUTO_BLOCKS_KONTEXT.keys()

    @property
    def description(self):
        # Overridden so the Kontext pipeline does not report the plain-Flux description it would inherit.
        return (
            "Auto Modular pipeline for text-to-image and image-to-image using Flux Kontext.\n"
            + "- for text-to-image generation, all you need to provide is `prompt`\n"
            + "- for image-to-image generation, you need to provide either `image` or `image_latents`"
        )


# Fixed (non-auto) text-to-image Flux preset: every step listed explicitly in execution order.
TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
        ("input", FluxTextInputStep()),
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxSetTimestepsStep()),
        ("prepare_rope_inputs", FluxRoPEInputsStep()),
        ("denoise", FluxDenoiseStep()),
        ("decode", FluxDecodeStep()),
    ]
)

# Fixed (non-auto) image-to-image Flux preset: every step listed explicitly in execution order.
IMAGE2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
        # NOTE(review): unlike FluxImg2ImgVaeEncoderBlocks, no FluxProcessImagesInputStep runs
        # before this encoder — confirm FluxVaeEncoderDynamicStep does not need its output.
        ("vae_encoder", FluxVaeEncoderDynamicStep()),
        ("input", FluxImg2ImgInputStep()),
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
        ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
        ("prepare_rope_inputs", FluxRoPEInputsStep()),
        ("denoise", FluxDenoiseStep()),
        ("decode", FluxDecodeStep()),
    ]
)

# Fixed (non-auto) Flux Kontext preset: every step listed explicitly in execution order.
FLUX_KONTEXT_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
        # NOTE(review): unlike FluxKontextVaeEncoderBlocks, no FluxKontextProcessImagesInputStep runs
        # before this encoder — confirm FluxVaeEncoderDynamicStep does not need its output.
        ("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
        ("input", FluxKontextInputStep()),
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxSetTimestepsStep()),
        ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
        ("denoise", FluxKontextDenoiseStep()),
        ("decode", FluxDecodeStep()),
    ]
)

# Registry mapping a preset name to its ordered block definitions.
ALL_BLOCKS = dict(
    text2image=TEXT2IMAGE_BLOCKS,
    img2img=IMAGE2IMAGE_BLOCKS,
    auto=AUTO_BLOCKS,
    auto_kontext=AUTO_BLOCKS_KONTEXT,
    kontext=FLUX_KONTEXT_BLOCKS,
)