# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0/src/diffusers/models/embeddings.py
import torch

from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed, CogVideoXPatchEmbed
import torch.distributed
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.model_executor.layers import xFuserLayerBaseWrapper
from xfuser.model_executor.layers import xFuserLayerWrappersRegister
from xfuser.logger import init_logger

logger = init_logger(__name__)
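
# These wrappers adapt diffusers' patch-embedding layers to xDiT's parallel
# runtime: each rank (and, in pipeline mode, each pipeline patch) sees only a
# slice of the latent, so the positional embedding has to be sliced to the
# matching token range before it is added to the projected latent.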


@xFuserLayerWrappersRegister.register(PatchEmbed)
class xFuserPatchEmbedWrapper(xFuserLayerBaseWrapper):
    def __init__(
        self,
        patch_embedding: PatchEmbed,
    ):
        super().__init__(
            module=patch_embedding,
        )
        self.module: PatchEmbed
        self.pos_embed = None

    def forward(self, latent):
        height = (
            get_runtime_state().input_config.height
            // get_runtime_state().vae_scale_factor
        )
        width = latent.shape[-1]
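        # Note: `height` comes from the global input config (under patch
        # parallelism the local latent may hold only a height slice), while
        # `width` is read from the latent itself, which is assumed to span the
        # full width since splitting happens along the height dimension.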
        if getattr(self.module, "pos_embed_max_size", None) is None:
            height, width = (
                height // self.module.patch_size,
                width // self.module.patch_size,
            )

        latent = self.module.proj(latent)
        if self.module.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.module.layer_norm:
            # TODO: not yet verified whether this norm is compatible with the
            # patch-sliced latent
            latent = self.module.norm(latent)

        # [2, 4096 / c, 1152]

        if self.module.pos_embed is None:
            return latent.to(latent.dtype)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)

        # TODO: There might be a faster way to generate a smaller pos_embed
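        # Two cases: models exposing `pos_embed_max_size` (e.g. SD3-style
        # embeddings) crop a window out of a pre-computed table, while the
        # others regenerate the 2D sin-cos embedding for the current grid size
        # and cache it on the module.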
        if getattr(self.module, "pos_embed_max_size", None):
            pos_embed = self.module.cropped_pos_embed(height, width)
        else:
            if self.module.height != height or self.module.width != width:
                pos_embed = get_2d_sincos_pos_embed(
                    embed_dim=self.module.pos_embed.shape[-1],
                    grid_size=(height, width),
                    base_size=self.module.base_size,
                    interpolation_scale=self.module.interpolation_scale,
                )
                pos_embed = torch.from_numpy(pos_embed)
                self.module.pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
                self.module.height = height
                self.module.width = width
                pos_embed = self.module.pos_embed
            else:
                pos_embed = self.module.pos_embed
        b, c, h = pos_embed.shape  # (batch, num_tokens, embed_dim)

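        # In pipeline patch mode, add only the token range belonging to the
        # pipeline patch currently being processed; otherwise concatenate the
        # ranges of all patches assigned to this rank and add them in one go.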
        if get_runtime_state().patch_mode:
            start, end = get_runtime_state().pp_patches_token_start_end_idx_global[
                get_runtime_state().pipeline_patch_idx
            ]
            pos_embed = pos_embed[
                :,
                start:end,
                :,
            ]
        else:
            pos_embed_list = [
                pos_embed[
                    :,
                    get_runtime_state()
                    .pp_patches_token_start_end_idx_global[i][0] : get_runtime_state()
                    .pp_patches_token_start_end_idx_global[i][1],
                    :,
                ]
                for i in range(get_runtime_state().num_pipeline_patch)
            ]
            pos_embed = torch.cat(pos_embed_list, dim=1)

        return (latent + pos_embed).to(latent.dtype)
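
# Rough usage sketch (hypothetical helper name; the actual lookup lives in
# xFuserLayerWrappersRegister and may differ):
#
#   wrapper_cls = xFuserLayerWrappersRegister.get_wrapper(transformer.pos_embed)
#   transformer.pos_embed = wrapper_cls(transformer.pos_embed)
#
# The diffusers PatchEmbed is replaced in place, so downstream code keeps the
# same interface but receives positional embeddings sliced for this rank.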


@xFuserLayerWrappersRegister.register(CogVideoXPatchEmbed)
class xFuserCogVideoXPatchEmbedWrapper(xFuserLayerBaseWrapper):
    def __init__(
        self,
        patch_embedding: CogVideoXPatchEmbed,
    ):
        super().__init__(
            module=patch_embedding,
        )
        self.module: CogVideoXPatchEmbed

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        # `height` below is the height of the latent patch held by this GPU;
        # `sum_height` is the full latent height of the video
        sum_height = (
            get_runtime_state().input_config.height
            // get_runtime_state().vae_scale_factor_spatial
        )
        text_embeds = self.text_proj(text_embeds)

        batch_size, num_frames, channels, height, width = image_embeds.shape
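        # Two patchify paths follow: when `patch_size_t` is unset each frame is
        # embedded independently via the 2D conv projection; otherwise the
        # latent is rearranged into (p_t x p x p) spatio-temporal patches and
        # projected with the linear layer.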
        
        if self.patch_size_t is None:
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)
 
        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
           
        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != sum_height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )

            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
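            # `pre_time_compression_frames` maps the latent frame count back to
            # the pixel-space frame count the positional-embedding table was
            # built for (CogVideoX's temporal VAE keeps the first frame
            # uncompressed, hence the `(n - 1) * ratio + 1` form).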

            if (
                self.sample_height != sum_height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(sum_height, width, pre_time_compression_frames)
                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
            else:
                pos_embedding = self.pos_embedding

            # extract the image part of the positional embedding
            pos_embedding = pos_embedding[:, self.max_text_seq_length :]

            # slice the positional embedding
            post_patch_height = sum_height // self.patch_size
            post_patch_width = width // self.patch_size
            post_time_compression_frames = (pre_time_compression_frames - 1) // self.temporal_compression_ratio + 1

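            # Each latent frame owns a contiguous block of
            # post_patch_height * post_patch_width tokens; within every frame's
            # block, keep only the token range assigned to this rank's patches.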
            frame_stride = post_patch_height * post_patch_width
            start, end = get_runtime_state().pp_patches_token_start_end_idx_global[0]
            pos_embed_list = [
                pos_embedding[:, frame_stride * i + start : frame_stride * i + end, :]
                for i in range(post_time_compression_frames)
            ]
            pos_embedding = torch.cat(pos_embed_list, dim=1)

            embeds[:, self.max_text_seq_length :] += pos_embedding

        return embeds