# encoders.py (19 KB)
# Newer / Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import List, Optional, Union

import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

22
from ...configuration_utils import FrozenDict
23
from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
24
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
25
from ...models import AutoencoderKL
26
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
27
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
28
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from .modular_pipeline import FluxModularPipeline


# `ftfy` is optional: it is only imported when available, and `basic_clean` below calls
# `ftfy.fix_text`, so that helper requires ftfy to be installed when it is used.
if is_ftfy_available():
    import ftfy


# Module-level logger; used for the prompt-truncation warnings emitted by the text encoder step.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def basic_clean(text):
    """Fix broken text with ftfy, undo HTML escaping twice, and strip surrounding whitespace.

    Args:
        text: Raw prompt string.

    Returns:
        The cleaned string. Requires the optional ``ftfy`` package to be importable.
    """
    repaired = ftfy.fix_text(text)
    # Unescape two times so doubly-escaped entities (e.g. "&amp;amp;") are fully decoded.
    for _ in range(2):
        repaired = html.unescape(repaired)
    return repaired.strip()


def whitespace_clean(text):
    """Collapse every whitespace run into a single space and trim both ends of ``text``."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()


def prompt_clean(text):
    """Run ``basic_clean`` and then ``whitespace_clean`` on a prompt and return the result."""
    return whitespace_clean(basic_clean(text))


56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Extract latents from a VAE encoder output object.
# - If the object has `latent_dist` and sample_mode == "sample": draw from it with `generator`.
# - If the object has `latent_dist` and sample_mode == "argmax": return the distribution's mode.
# - Otherwise, if the object has a `latents` attribute: return it.
# - Otherwise: raise AttributeError.
# NOTE(review): `encoder_output` is annotated as torch.Tensor, but it is accessed as an object with
# `latent_dist` / `latents` attributes. The body is left verbatim because of the "Copied from" marker.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"):
    """Encode ``image`` with ``vae`` and return shifted/scaled latents.

    Args:
        vae: The autoencoder whose ``encode`` output is passed to ``retrieve_latents``.
        image: Image batch; indexed along dim 0 when per-sample generators are given.
        generator: A single generator, or a list with one generator per batch entry
            (indexed by position along dim 0 of ``image``).
        sample_mode: Forwarded to ``retrieve_latents`` ("sample" or "argmax").

    Returns:
        ``(latents - vae.config.shift_factor) * vae.config.scaling_factor``.
    """
    if not isinstance(generator, list):
        latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
    else:
        # Encode one sample at a time so each batch entry uses its own generator.
        per_sample = []
        for idx in range(image.shape[0]):
            encoded = vae.encode(image[idx : idx + 1])
            per_sample.append(retrieve_latents(encoded, generator=generator[idx], sample_mode=sample_mode))
        latents = torch.cat(per_sample, dim=0)

    shift = vae.config.shift_factor
    scale = vae.config.scaling_factor
    return (latents - shift) * scale


class FluxProcessImagesInputStep(ModularPipelineBlocks):
    """Turn ``resized_image`` (preferred) or ``image`` into ``processed_image`` via the image processor."""

    model_name = "flux"

    @property
    def description(self) -> str:
        return "Image Preprocess step."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam(name="processed_image")]

    @staticmethod
    def check_inputs(height, width, vae_scale_factor):
        """Raise ``ValueError`` if a given height/width is not divisible by ``vae_scale_factor * 2``.

        ``None`` values are skipped (defaults are filled in later by the caller).
        """
        if height is not None and height % (vae_scale_factor * 2) != 0:
            raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")

        if width is not None and width % (vae_scale_factor * 2) != 0:
            raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState):
        """Preprocess the input image and store it as ``processed_image``.

        - If ``resized_image`` is given, it is used as-is. The target size is read from its first
          entry's ``.size``, unpacked as ``(width, height)``, so PIL-style images are assumed.
          A single image is accepted as well as a list.
        - Otherwise ``image`` is used. Any explicit ``height``/``width`` is validated, and missing
          values fall back to ``components.default_height`` / ``components.default_width``.

        Raises:
            ValueError: If both ``resized_image`` and ``image`` are None, or if ``height``/``width``
                fail ``check_inputs``.

        Returns:
            The ``(components, state)`` tuple.
        """
        block_state = self.get_block_state(state)

        if block_state.resized_image is None and block_state.image is None:
            raise ValueError("`resized_image` and `image` cannot be None at the same time")

        if block_state.resized_image is None:
            image = block_state.image
            self.check_inputs(
                height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
            )
            height = block_state.height or components.default_height
            width = block_state.width or components.default_width
        else:
            image = block_state.resized_image
            # Wrap a lone image so that indexing the first entry below works for both forms.
            if is_valid_image(image):
                image = [image]
            width, height = image[0].size

        block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)

        self.set_block_state(state, block_state)
        return components, state
141
142


143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
    """Resize and preprocess the optional Flux Kontext reference image into ``processed_image``."""

    model_name = "flux-kontext"

    @property
    def description(self) -> str:
        return (
            "Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n"
            "Kontext works as a T2I model, too, in case no input image is provided."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam(name="processed_image")]

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState):
        """Compute the target resolution from the first image and preprocess all images to it.

        If ``image`` is None, ``processed_image`` is set to None (text-to-image mode).

        Raises:
            ValueError: If ``image`` is neither a valid image nor a valid list of images.

        Returns:
            The ``(components, state)`` tuple.
        """
        from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS

        block_state = self.get_block_state(state)
        images = block_state.image

        if images is None:
            block_state.processed_image = None

        else:
            multiple_of = components.image_processor.config.vae_scale_factor

            if not is_valid_image_imagelist(images):
                raise ValueError(f"Images must be image or list of images but are {type(images)}")

            if is_valid_image(images):
                images = [images]

            # The target resolution is derived from the first image only.
            img = images[0]
            image_height, image_width = components.image_processor.get_default_height_width(img)
            aspect_ratio = image_width / image_height

            _auto_resize = block_state._auto_resize
            if _auto_resize:
                # Kontext is trained on specific resolutions, using one of them is recommended
                _, image_width, image_height = min(
                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
                )
            # Round both dimensions down to a multiple of the image processor's scale factor.
            image_width = image_width // multiple_of * multiple_of
            image_height = image_height // multiple_of * multiple_of
            # NOTE(review): `images` is usually a list here. Confirm whether
            # `VaeImageProcessor.resize` handles lists, or whether resizing effectively happens
            # only inside `preprocess`.
            images = components.image_processor.resize(images, image_height, image_width)
            block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width)

        self.set_block_state(state, block_state)
        return components, state


209
210
211
212
class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
    """Encode a configurable image input into latents stored under a configurable output name."""

    model_name = "flux"

    def __init__(
        self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"
    ):
        """Initialize a VAE encoder step for converting images to latent representations.

        Both the input and output names are configurable, so this block can process different image
        inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").

        Args:
            input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
                Examples: "processed_image" or "processed_control_image"
            output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
                Examples: "image_latents" or "control_image_latents"
            sample_mode (str, optional): Forwarded to `encode_vae_image`. "sample" draws from the VAE's
                latent distribution; "argmax" takes its mode. Defaults to "sample".

        Examples:
            # Basic usage with default settings:
            FluxVaeEncoderDynamicStep()

            # Custom input/output names for a control image:
            FluxVaeEncoderDynamicStep(input_name="processed_control_image", output_name="control_image_latents")
        """
        self._image_input_name = input_name
        self._image_latents_output_name = output_name
        self.sample_mode = sample_mode
        super().__init__()

    @property
    def description(self) -> str:
        return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        components = [ComponentSpec("vae", AutoencoderKL)]
        return components

    @property
    def inputs(self) -> List[InputParam]:
        inputs = [InputParam(self._image_input_name), InputParam("generator")]
        return inputs

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                self._image_latents_output_name,
                type_hint=torch.Tensor,
                description="The latents representing the reference image",
            )
        ]

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState):
        """Encode the configured input image; store ``None`` as the output when there is no image.

        Returns:
            The ``(components, state)`` tuple. The latents are written onto ``state``.
        """
        block_state = self.get_block_state(state)
        image = getattr(block_state, self._image_input_name)

        if image is None:
            setattr(block_state, self._image_latents_output_name, None)
        else:
            device = components._execution_device
            dtype = components.vae.dtype
            # Move the image to the execution device and the VAE's dtype before encoding.
            image = image.to(device=device, dtype=dtype)

            # Encode image into latents
            image_latents = encode_vae_image(
                image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode
            )
            setattr(block_state, self._image_latents_output_name, image_latents)

        self.set_block_state(state, block_state)

        return components, state


286
class FluxTextEncoderStep(ModularPipelineBlocks):
    """Produce Flux text conditioning: CLIP pooled embeddings plus T5 sequence embeddings."""

    model_name = "flux"

    @property
    def description(self) -> str:
        return "Text Encoder step that generate text_embeddings to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("text_encoder_2", T5EncoderModel),
            ComponentSpec("tokenizer_2", T5TokenizerFast),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="pooled text embeddings used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Raise ``ValueError`` if ``prompt`` or ``prompt_2`` is set but is neither ``str`` nor ``list``."""
        for prompt in [block_state.prompt, block_state.prompt_2]:
            if prompt is not None and not isinstance(prompt, (str, list)):
                raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    def _get_t5_prompt_embeds(
        components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device
    ):
        """Encode ``prompt`` with ``tokenizer_2``/``text_encoder_2`` into per-token embeddings.

        Prompts are padded/truncated to ``max_sequence_length``. A warning is logged with the
        truncated text. The result is cast to ``text_encoder_2``'s dtype and moved to ``device``.
        """
        dtype = components.text_encoder_2.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt

        if isinstance(components, TextualInversionLoaderMixin):
            prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)

        text_inputs = components.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # Re-tokenize without truncation only to detect and report what was cut off.
        untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        return prompt_embeds

    @staticmethod
    def _get_clip_prompt_embeds(components, prompt: Union[str, List[str]], device: torch.device):
        """Encode ``prompt`` with the CLIP tokenizer/text encoder and return its pooled output.

        Prompts are padded/truncated to ``tokenizer.model_max_length``, and a warning is logged with
        any truncated text.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt

        if isinstance(components, TextualInversionLoaderMixin):
            prompt = components.maybe_convert_prompt(prompt, components.tokenizer)

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=components.tokenizer.model_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        tokenizer_max_length = components.tokenizer.model_max_length
        # Re-tokenize without truncation only to detect and report what was cut off.
        untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)

        return prompt_embeds

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        """Return ``(prompt_embeds, pooled_prompt_embeds)`` for the given prompts.

        ``prompt`` feeds CLIP (pooled embedding), and ``prompt_2`` feeds T5, falling back to
        ``prompt`` when empty/None. If ``prompt_embeds`` is supplied, no encoding happens and both
        given tensors are returned as-is. When ``lora_scale`` is set and ``components`` is a
        ``FluxLoraLoaderMixin``, the text encoders' LoRA layers are scaled for the call and then
        unscaled (PEFT backend only).
        """
        device = device or components._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
            components._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if components.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder, lora_scale)
            if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds(
                components,
                prompt=prompt,
                device=device,
            )
            prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
                components,
                prompt=prompt_2,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if components.text_encoder is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        return prompt_embeds, pooled_prompt_embeds

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState):
        """Encode ``prompt`` / ``prompt_2`` and store the embeddings on the block state.

        Returns:
            The ``(components, state)`` tuple.
        """
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
            block_state.joint_attention_kwargs.get("scale", None)
            if block_state.joint_attention_kwargs is not None
            else None
        )
        block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            # Forward the user's T5 prompt; encode_prompt falls back to `prompt` when this is None.
            prompt_2=block_state.prompt_2,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            max_sequence_length=block_state.max_sequence_length,
            lora_scale=block_state.text_encoder_lora_scale,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state