cfg_parallel.py 8.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Mixin providing shared classifier-free guidance (CFG) helpers for Diffusion pipelines.
"""

from abc import ABCMeta
from typing import Any

import torch

from vllm_omni.diffusion.distributed.parallel_state import (
    get_cfg_group,
    get_classifier_free_guidance_rank,
    get_classifier_free_guidance_world_size,
)


class CFGParallelMixin(metaclass=ABCMeta):
    """
    Base Mixin class for Diffusion pipelines providing shared CFG methods.

    All pipelines should inherit from this class to reuse
    classifier-free guidance logic.
    """

    def predict_noise_maybe_with_cfg(
        self,
        do_true_cfg: bool,
        true_cfg_scale: float,
        positive_kwargs: dict[str, Any],
        negative_kwargs: dict[str, Any] | None,
        cfg_normalize: bool = True,
        output_slice: int | None = None,
    ) -> torch.Tensor | None:
        """
        Predict noise with optional classifier-free guidance.

        Args:
            do_true_cfg: Whether to apply CFG
            true_cfg_scale: CFG scale factor
            positive_kwargs: Kwargs for positive/conditional prediction
            negative_kwargs: Kwargs for negative/unconditional prediction
            cfg_normalize: Whether to normalize CFG output (default: True)
            output_slice: If set, slice output to [:, :output_slice] for image editing

        Returns:
            Predicted noise tensor (only valid on rank 0 in CFG parallel mode)
        """
        if do_true_cfg:
            # Automatically detect CFG parallel configuration
            cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1

            if cfg_parallel_ready:
                # Enable CFG-parallel: rank0 computes positive, rank1 computes negative.
                cfg_group = get_cfg_group()
                cfg_rank = get_classifier_free_guidance_rank()

                if cfg_rank == 0:
                    local_pred = self.predict_noise(**positive_kwargs)
                else:
                    local_pred = self.predict_noise(**negative_kwargs)

                # Slice output for image editing pipelines (remove condition latents)
                if output_slice is not None:
                    local_pred = local_pred[:, :output_slice]

                gathered = cfg_group.all_gather(local_pred, separate_tensors=True)

                if cfg_rank == 0:
                    noise_pred = gathered[0]
                    neg_noise_pred = gathered[1]
                    noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize)
                    return noise_pred
                else:
                    return None
            else:
                # Sequential CFG: compute both positive and negative
                positive_noise_pred = self.predict_noise(**positive_kwargs)
                negative_noise_pred = self.predict_noise(**negative_kwargs)

                # Slice output for image editing pipelines
                if output_slice is not None:
                    positive_noise_pred = positive_noise_pred[:, :output_slice]
                    negative_noise_pred = negative_noise_pred[:, :output_slice]

                noise_pred = self.combine_cfg_noise(
                    positive_noise_pred, negative_noise_pred, true_cfg_scale, cfg_normalize
                )
                return noise_pred
        else:
            # No CFG: only compute positive/conditional prediction
            pred = self.predict_noise(**positive_kwargs)
            if output_slice is not None:
                pred = pred[:, :output_slice]
            return pred

    def cfg_normalize_function(self, noise_pred: torch.Tensor, comb_pred: torch.Tensor) -> torch.Tensor:
        """
        Normalize the combined noise prediction.

        Args:
            noise_pred: positive noise prediction
            comb_pred: combined noise prediction after CFG

        Returns:
            Normalized noise prediction tensor
        """
        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
        noise_pred = comb_pred * (cond_norm / noise_norm)
        return noise_pred

    def combine_cfg_noise(
        self, noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, true_cfg_scale: float, cfg_normalize: bool = False
    ) -> torch.Tensor:
        """
        Combine conditional and unconditional noise predictions with CFG.

        Args:
            noise_pred: Conditional noise prediction
            neg_noise_pred: Unconditional noise prediction
            true_cfg_scale: CFG scale factor
            cfg_normalize: Whether to normalize the combined prediction (default: False)

        Returns:
            Combined noise prediction tensor
        """
        comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

        if cfg_normalize:
            noise_pred = self.cfg_normalize_function(noise_pred, comb_pred)
        else:
            noise_pred = comb_pred

        return noise_pred

    def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Forward pass through transformer to predict noise.

        Subclasses should override this if they need custom behavior,
        but the default implementation calls self.transformer.
        """
        return self.transformer(*args, **kwargs)[0]

    def diffuse(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """
        Diffusion loop with optional classifier-free guidance.

        Subclasses MUST implement this method to define the complete
        diffusion/denoising loop for their specific model.

        Typical implementation pattern:
        ```python
        def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, ...):
            for t in timesteps:
                # Prepare kwargs for positive and negative predictions
                positive_kwargs = {...}
                negative_kwargs = {...}

                # Predict noise with automatic CFG handling
                noise_pred = self.predict_noise_maybe_with_cfg(
                    do_true_cfg=True,
                    true_cfg_scale=self.guidance_scale,
                    positive_kwargs=positive_kwargs,
                    negative_kwargs=negative_kwargs,
                )

                # Step scheduler with automatic CFG sync
                latents = self.scheduler_step_maybe_with_cfg(
                    noise_pred, t, latents, do_true_cfg=True
                )

            return latents
        ```
        """
        raise NotImplementedError("Subclasses must implement diffuse")

    def scheduler_step(self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        """
        Step the scheduler.

        Args:
            noise_pred: Predicted noise
            t: Current timestep
            latents: Current latents

        Returns:
            Updated latents after scheduler step
        """
        return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    def scheduler_step_maybe_with_cfg(
        self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor, do_true_cfg: bool
    ) -> torch.Tensor:
        """
        Step the scheduler with (maybe) automatic CFG parallel synchronization.

        In CFG parallel mode, only rank 0 computes the scheduler step,
        then broadcasts the result to other ranks.

        Args:
            noise_pred: Predicted noise (only valid on rank 0 in CFG parallel)
            t: Current timestep
            latents: Current latents
            do_true_cfg: Whether CFG is enabled

        Returns:
            Updated latents (synchronized across all CFG ranks)
        """
        # Automatically detect CFG parallel configuration
        cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1

        if cfg_parallel_ready:
            cfg_group = get_cfg_group()
            cfg_rank = get_classifier_free_guidance_rank()

            # Only rank 0 computes the scheduler step
            if cfg_rank == 0:
                latents = self.scheduler_step(noise_pred, t, latents)

            # Broadcast the updated latents to all ranks
            latents = latents.contiguous()
            cfg_group.broadcast(latents, src=0)
        else:
            # No CFG parallel: directly compute scheduler step
            latents = self.scheduler_step(noise_pred, t, latents)

        return latents