import unittest

import torch
from transformers import AutoTokenizer, Gemma2Config, Gemma2Model

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    Lumina2Text2ImgPipeline,
    Lumina2Transformer2DModel,
)

from ..test_pipelines_common import PipelineTesterMixin


class Lumina2Text2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
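    """Fast tests for Lumina2Text2ImgPipeline built from tiny dummy components."""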
    pipeline_class = Lumina2Text2ImgPipeline
    params = frozenset(
        [
            "prompt",
            "height",
            "width",
            "guidance_scale",
            "negative_prompt",
            "prompt_embeds",
            "negative_prompt_embeds",
        ]
    )
    batch_params = frozenset(["prompt", "negative_prompt"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

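    # Feature flags read by PipelineTesterMixin to decide which shared tests apply.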
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True

    def get_dummy_components(self):
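        """Build tiny versions of every pipeline component so the tests run quickly on CPU."""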
        torch.manual_seed(0)
        transformer = Lumina2Transformer2DModel(
            sample_size=4,
            patch_size=2,
            in_channels=4,
            hidden_size=8,
            num_layers=2,
            num_attention_heads=1,
            num_kv_heads=1,
            multiple_of=16,
            ffn_dim_multiplier=None,
            norm_eps=1e-5,
            scaling_factor=1.0,
            axes_dim_rope=[4, 2, 2],
            cap_feat_dim=8,
        )

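        # Minimal VAE: a single block and 4 latent channels keep memory use negligible.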
        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=4,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")

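        # Tiny Gemma2 text encoder; its hidden_size must match the transformer's cap_feat_dim (8).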
        torch.manual_seed(0)
        config = Gemma2Config(
            head_dim=4,
            hidden_size=8,
            intermediate_size=8,
            num_attention_heads=2,
            num_hidden_layers=2,
            num_key_value_heads=2,
            sliding_window=2,
        )
        text_encoder = Gemma2Model(config)

        components = {
            "transformer": transformer,
            "vae": vae.eval(),
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
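        # torch.Generator is not fully supported on MPS, so seed the global RNG there instead.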
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 32,
            "width": 32,
            "output_type": "np",
        }
        return inputs
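
    # A minimal end-to-end smoke test (a sketch, not part of the original suite):
    # it assumes the PipelineTesterMixin conventions above and checks that the
    # tiny dummy pipeline runs and produces an output of the requested size.
    def test_dummy_inference_output_shape(self):
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images[0]

        # output_type="np" returns an (H, W, C) float array; height/width above are 32
        self.assertEqual(image.shape, (32, 32, 3))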