"mmdet3d/configs/_base_/datasets/scannet_3d.py" did not exist on "aea26ac7cd544678db5bff157a658352449c29ac"
test_pipeline_flux_redux.py 3.19 KB
Newer Older
Aryan's avatar
Aryan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gc
import unittest

import numpy as np
import pytest
import torch

from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    numpy_cosine_similarity_distance,
    require_big_gpu_with_torch_cuda,
    slow,
    torch_device,
)


@slow
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxReduxSlowTests(unittest.TestCase):
    """Slow integration test that chains the Flux Redux prior pipeline into a base Flux pipeline.

    The prior pipeline is called with a reference image; its output is unpacked as keyword
    arguments into the base pipeline, and a slice of the generated image is compared against
    stored reference values.
    """

    # Prior (Redux) pipeline class and the checkpoint it is loaded from.
    pipeline_class = FluxPriorReduxPipeline
    repo_id = "YiYiXu/yiyi-redux"  # update to "black-forest-labs/FLUX.1-Redux-dev" once PR is merged
    # Base pipeline class and checkpoint that consume the prior pipeline's output.
    base_pipeline_class = FluxPipeline
    base_repo_id = "black-forest-labs/FLUX.1-schnell"

    def setUp(self):
        """Run garbage collection and empty the CUDA cache before each test."""
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        """Run garbage collection and empty the CUDA cache after each test."""
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, seed=0):
        """Return the keyword arguments for the Redux prior pipeline.

        ``device`` and ``seed`` are accepted but not used here; the only input is the
        reference image fetched from a fixed URL.
        """
        image_url = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
        return {"image": load_image(image_url)}

    def get_base_pipeline_inputs(self, device, seed=0):
        """Return the keyword arguments for the base Flux pipeline, with a seeded generator.

        When ``device`` is an "mps" device the generator returned by ``torch.manual_seed`` is
        used; otherwise a dedicated CPU generator is created and seeded.
        """
        on_mps = str(device).startswith("mps")
        rng = torch.manual_seed(seed) if on_mps else torch.Generator(device="cpu").manual_seed(seed)

        base_kwargs = {
            "num_inference_steps": 2,
            "guidance_scale": 2.0,
            "output_type": "np",
            "generator": rng,
        }
        return base_kwargs

    def test_flux_redux_inference(self):
        # Load both pipelines in bfloat16; the base pipeline is loaded without text encoders.
        redux = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
        base = self.base_pipeline_class.from_pretrained(
            self.base_repo_id,
            torch_dtype=torch.bfloat16,
            text_encoder=None,
            text_encoder_2=None,
        )
        redux.to(torch_device)
        base.enable_model_cpu_offload()

        redux_kwargs = self.get_inputs(torch_device)
        base_kwargs = self.get_base_pipeline_inputs(torch_device)

        # The prior pipeline's output is forwarded to the base pipeline as keyword arguments.
        prior_output = redux(**redux_kwargs)
        generated = base(**base_kwargs, **prior_output).images[0]

        # NOTE(review): presumably `generated` is laid out (height, width, channels), which would
        # make this the first 10 pixels of the top row - confirm against the pipeline output.
        actual_pixels = generated[0, :10, :10]

        # Reference values, three per row; flattened below before the comparison.
        expected_pixels = np.array(
            [
                [0.30078125, 0.37890625, 0.46875],
                [0.28125, 0.36914062, 0.47851562],
                [0.28515625, 0.375, 0.4765625],
                [0.28125, 0.375, 0.48046875],
                [0.27929688, 0.37695312, 0.47851562],
                [0.27734375, 0.38085938, 0.4765625],
                [0.2734375, 0.38085938, 0.47265625],
                [0.27539062, 0.37890625, 0.47265625],
                [0.27734375, 0.37695312, 0.47070312],
                [0.27929688, 0.37890625, 0.47460938],
            ],
            dtype=np.float32,
        )

        distance = numpy_cosine_similarity_distance(expected_pixels.flatten(), actual_pixels.flatten())

        assert distance < 1e-4