test_data_preprocess.py 3.73 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import unittest

import torch
from transformers import AutoTokenizer, T5EncoderModel

from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D


class TestAutoencoderKLCausal3D(unittest.TestCase):
    """Shape checks for ``AutoencoderKLCausal3D`` built from a tiny random config.

    The model is constructed from a small in-code config with random weights,
    so these tests run on CPU and need no pretrained checkpoints or network
    access.
    """

    def setUp(self) -> None:
        """Build a fresh tiny VAE and a seeded random video batch for each test."""
        # Seed so the model's random initialisation and the input video are
        # reproducible from run to run.
        torch.manual_seed(0)

        self.batch_size = 1
        # 9 == 1 + 2 * time_compression_ratio, i.e. a "first frame + whole
        # groups" length, which the causal encoder is presumably designed for.
        self.init_time_len = 9
        self.init_height = 16
        self.init_width = 16
        self.latent_channels = 4
        self.spatial_compression_ratio = 8
        self.time_compression_ratio = 4

        # Tiny model config: four encoder/decoder stages of 8 channels each so
        # the model builds and runs quickly.
        self.init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": self.latent_channels,
            "down_block_types": (
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
            ),
            "up_block_types": (
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
            ),
            "block_out_channels": (8, 8, 8, 8),
            "layers_per_block": 1,
            "act_fn": "silu",
            # Presumably used for GroupNorm, so it must divide every entry of
            # block_out_channels (8 here) — TODO confirm.
            "norm_num_groups": 4,
            "scaling_factor": 0.476986,
            "spatial_compression_ratio": self.spatial_compression_ratio,
            "time_compression_ratio": self.time_compression_ratio,
            "mid_block_add_attention": True,
        }

        self.model = AutoencoderKLCausal3D(**self.init_dict)
        # Inference mode: these tests only inspect forward-pass outputs.
        self.model.eval()

        # Random video batch laid out as (batch, channels, time, height, width).
        self.input_tensor = torch.rand(
            self.batch_size,
            3,
            self.init_time_len,
            self.init_height,
            self.init_width,
        )

    def test_encode_shape(self) -> None:
        """The latent sample has the configured channel count and compressed T/H/W."""
        # A shape check needs no autograd graph; skipping it saves time and memory.
        with torch.no_grad():
            vae_encoder_output = self.model.encode(self.input_tensor)
            # The returned posterior exposes .sample(); its shape is the latent shape.
            sample_shape = vae_encoder_output["latent_dist"].sample().shape

        # Causal temporal compression (as described for HunyuanVideo's 3D VAE):
        # 1 + k * time_compression_ratio input frames -> 1 + k latent frames.
        # Spatial axes are divided by spatial_compression_ratio.
        expected_shape = (
            self.batch_size,
            self.latent_channels,
            (self.init_time_len - 1) // self.time_compression_ratio + 1,
            self.init_height // self.spatial_compression_ratio,
            self.init_width // self.spatial_compression_ratio,
        )

        # torch.Size is a tuple subclass, so it compares equal to a plain tuple.
        self.assertEqual(
            sample_shape,
            expected_shape,
            f"Encoded sample shape {sample_shape} does not match {expected_shape}.",
        )


if __name__ == "__main__":
    # Executed only when this file is run as a script (not on import or under
    # pytest collection): discover and run every TestCase in this module.
    unittest.main(module="__main__")