# File: test_pipeline_flux2.py (6.61 KB)
# Committed by: Sayak Paul
import unittest

import numpy as np
import torch
from transformers import AutoProcessor, Mistral3Config, Mistral3ForConditionalGeneration

from diffusers import (
    AutoencoderKLFlux2,
    FlowMatchEulerDiscreteScheduler,
    Flux2Pipeline,
    Flux2Transformer2DModel,
)

from ...testing_utils import (
    torch_device,
)
from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fused_layers_exist,
)


class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Tests for ``Flux2Pipeline`` assembled from tiny, seeded dummy components.

    The class supplies the component/input factories (``get_dummy_components``,
    ``get_dummy_inputs``) and adds two pipeline-specific tests: one for fused
    QKV projections and one for the output image shape.
    """

    pipeline_class = Flux2Pipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])

    # Class-level switches; presumably consumed by PipelineTesterMixin — confirm there.
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        """Create the small component set used to instantiate the pipeline.

        The global torch seed is reset to 0 right before the transformer, the
        text encoder and the VAE are constructed.

        Args:
            num_layers: Passed as ``num_layers`` to ``Flux2Transformer2DModel``.
            num_single_layers: Passed as ``num_single_layers`` to
                ``Flux2Transformer2DModel``.

        Returns:
            A dict with the keys ``scheduler``, ``text_encoder``, ``tokenizer``,
            ``transformer`` and ``vae``.
        """
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,  # Hardcoded in original code
            axes_dims_rope=[4, 4, 4, 4],
        )

        # Settings for the language part of the dummy text encoder.
        text_config = {
            "model_type": "mistral",
            "vocab_size": 32000,
            "hidden_size": 16,
            "intermediate_size": 37,
            "max_position_embeddings": 512,
            "num_attention_heads": 4,
            "num_hidden_layers": 1,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "bos_token_id": 2,
            "eos_token_id": 3,
            "pad_token_id": 4,
        }
        # Settings for the vision part of the dummy text encoder.
        vision_config = {
            "model_type": "pixtral",
            "hidden_size": 16,
            "num_hidden_layers": 1,
            "num_attention_heads": 4,
            "intermediate_size": 37,
            "image_size": 30,
            "patch_size": 6,
            "num_channels": 3,
        }
        # NOTE(review): `model_dtype` may have been meant as `model_type` — confirm
        # against Mistral3Config before changing; kept exactly as written.
        config = Mistral3Config(
            text_config=text_config,
            vision_config=vision_config,
            bos_token_id=2,
            eos_token_id=3,
            pad_token_id=4,
            model_dtype="mistral3",
            image_seq_length=4,
            vision_feature_layer=-1,
            image_token_index=1,
        )

        torch.manual_seed(0)
        text_encoder = Mistral3ForConditionalGeneration(config)
        processor_repo = "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor"
        tokenizer = AutoProcessor.from_pretrained(processor_repo)

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        components = {
            "scheduler": FlowMatchEulerDiscreteScheduler(),
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return the keyword arguments for one small pipeline call.

        Args:
            device: Target device; only inspected to detect an ``mps`` device.
            seed: Seed applied to the generator placed in the inputs.

        Returns:
            A dict of call arguments (prompt, generator, step count, guidance
            scale, 8x8 size, sequence length, output type, encoder out layers).
        """
        on_mps = str(device).startswith("mps")
        if on_mps:
            # For mps, use whatever torch.manual_seed hands back after seeding.
            generator = torch.manual_seed(seed)
        else:
            cpu_generator = torch.Generator(device="cpu")
            generator = cpu_generator.manual_seed(seed)

        return {
            "prompt": "a dog is dancing",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 8,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }

    def test_fused_qkv_projections(self):
        """Check that fusing and then unfusing QKV projections leaves outputs unchanged.

        The pipeline is run three times (original, fused, unfused) with freshly
        created dummy inputs, and the same 3x3 corner slice of the first image
        is compared across the runs.
        """
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        def run_and_take_corner():
            # Fresh inputs on every run so each call gets a newly seeded generator.
            call_kwargs = self.get_dummy_inputs(device)
            images = pipe(**call_kwargs).images
            return images[0, -3:, -3:, -1]

        original_image_slice = run_and_take_corner()

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        fused_layers_present = check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"])
        self.assertTrue(
            fused_layers_present,
            "Something wrong with the fused attention layers. Expected all the attention projections to be fused.",
        )
        image_slice_fused = run_and_take_corner()

        pipe.transformer.unfuse_qkv_projections()
        image_slice_disabled = run_and_take_corner()

        # (left, right, tolerance used for both atol and rtol, failure message)
        comparisons = (
            (
                original_image_slice,
                image_slice_fused,
                1e-3,
                "Fusion of QKV projections shouldn't affect the outputs.",
            ),
            (
                image_slice_fused,
                image_slice_disabled,
                1e-3,
                "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled.",
            ),
            (
                original_image_slice,
                image_slice_disabled,
                1e-2,
                "Original outputs should match when fused QKV projections are disabled.",
            ),
        )
        for left, right, tolerance, message in comparisons:
            self.assertTrue(np.allclose(left, right, atol=tolerance, rtol=tolerance), message)

    def test_flux_image_output_shape(self):
        """Check that output height/width are floored to a multiple of ``vae_scale_factor * 2``."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        for height, width in [(32, 32), (72, 57)]:
            multiple = pipe.vae_scale_factor * 2
            expected_height = (height // multiple) * multiple
            expected_width = (width // multiple) * multiple

            # The same inputs dict (and generator) is reused; only the size changes.
            inputs["height"] = height
            inputs["width"] = width
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )