# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CFG Parallel Mixin for Qwen Image series
Shared by
- QwenImagePipeline
- QwenImageEditPipeline
- QwenImageEditPlusPipeline
- QwenImageLayeredPipeline
"""

import logging
from typing import Any

import torch

from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size

logger = logging.getLogger(__name__)


class QwenImageCFGParallelMixin(CFGParallelMixin):
    """
    Shared classifier-free-guidance (CFG) helpers for the Qwen Image pipelines.
    """

    def diffuse(
        self,
        prompt_embeds: torch.Tensor,
        prompt_embeds_mask: torch.Tensor,
        negative_prompt_embeds: torch.Tensor,
        negative_prompt_embeds_mask: torch.Tensor,
        latents: torch.Tensor,
        img_shapes: torch.Tensor,
        txt_seq_lens: torch.Tensor,
        negative_txt_seq_lens: torch.Tensor,
        timesteps: torch.Tensor,
        do_true_cfg: bool,
        guidance: torch.Tensor,
        true_cfg_scale: float,
        image_latents: torch.Tensor | None = None,
        cfg_normalize: bool = True,
        additional_transformer_kwargs: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        """
        Run the denoising loop, optionally with classifier-free guidance.

        Args:
            prompt_embeds: Positive prompt embeddings
            prompt_embeds_mask: Mask for positive prompt
            negative_prompt_embeds: Negative prompt embeddings
            negative_prompt_embeds_mask: Mask for negative prompt
            latents: Noise latents to denoise
            img_shapes: Image shape information
            txt_seq_lens: Text sequence lengths for positive prompts
            negative_txt_seq_lens: Text sequence lengths for negative prompts
            timesteps: Diffusion timesteps
            do_true_cfg: Whether to apply CFG
            guidance: Guidance scale tensor
            true_cfg_scale: CFG scale factor
            image_latents: Conditional image latents for editing (default: None)
            cfg_normalize: Whether to normalize CFG output (default: True)
            additional_transformer_kwargs: Extra kwargs to pass to transformer (default: None)

        Returns:
            Denoised latents
        """
        self.scheduler.set_begin_index(0)
        self.transformer.do_true_cfg = do_true_cfg
        extra_kwargs = additional_transformer_kwargs or {}

        for t in timesteps:
            if self.interrupt:
                continue
            self._current_timestep = t

            # Expand the scalar timestep across the batch dimension.
            t_batched = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype)

            # Editing pipelines concatenate condition latents along the sequence
            # dim; the matching slice length removes them from the output again.
            if image_latents is None:
                model_input = latents
                output_slice = None
            else:
                model_input = torch.cat([latents, image_latents], dim=1)
                output_slice = latents.size(1)

            # Arguments common to both the conditional and unconditional passes.
            shared_kwargs = {
                "hidden_states": model_input,
                "timestep": t_batched / 1000,
                "guidance": guidance,
                "img_shapes": img_shapes,
            }
            positive_kwargs = {
                **shared_kwargs,
                "encoder_hidden_states_mask": prompt_embeds_mask,
                "encoder_hidden_states": prompt_embeds,
                "txt_seq_lens": txt_seq_lens,
                **extra_kwargs,
            }
            negative_kwargs = None
            if do_true_cfg:
                negative_kwargs = {
                    **shared_kwargs,
                    "encoder_hidden_states_mask": negative_prompt_embeds_mask,
                    "encoder_hidden_states": negative_prompt_embeds,
                    "txt_seq_lens": negative_txt_seq_lens,
                    **extra_kwargs,
                }

            # Noise prediction; CFG-parallel splitting is handled by the mixin base.
            noise_pred = self.predict_noise_maybe_with_cfg(
                do_true_cfg,
                true_cfg_scale,
                positive_kwargs,
                negative_kwargs,
                cfg_normalize,
                output_slice,
            )

            # Advance x_t -> x_{t-1}, keeping CFG ranks in sync.
            latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)

        return latents

    def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool):
        """
        Validate whether CFG parallel is properly configured for the current generation request.

        When CFG parallel is enabled (cfg_parallel_world_size > 1), this method verifies that the necessary
        conditions are met for correct parallel execution. If validation fails, a warning is
        logged to help identify configuration issues.

        Args:
            true_cfg_scale: The classifier-free guidance scale value. Must be > 1 for CFG to
                have an effect.
            has_neg_prompt: Whether negative prompts or negative prompt embeddings are provided.
                Required for CFG to perform unconditional prediction.

        Returns:
            True if CFG parallel is disabled or all validation checks pass, False otherwise.

        Note:
            When CFG parallel is disabled (world_size == 1), this method always returns True
            as no parallel-specific validation is needed.
        """
        # Nothing to validate unless CFG parallelism is actually active.
        if get_classifier_free_guidance_world_size() == 1:
            return True

        is_valid = True
        if true_cfg_scale <= 1:
            logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.")
            is_valid = False
        elif not has_neg_prompt:
            logger.warning(
                "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings."
            )
            is_valid = False
        return is_valid