# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel


__all__ = [
    "XLMRobertaCLIP",
    "clip_xlm_roberta_vit_h_14",
    "CLIPModel",
]


def pos_interpolate(pos, seq_len):
    if pos.size(1) == seq_len:
        return pos
    else:
        src_grid = int(math.sqrt(pos.size(1)))
        tar_grid = int(math.sqrt(seq_len))
        n = pos.size(1) - src_grid * src_grid
        return torch.cat(
            [
                pos[:, :n],
                F.interpolate(pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), size=(tar_grid, tar_grid), mode="bicubic", align_corners=False).flatten(2).transpose(1, 2),
            ],
            dim=1,
        )
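
# Illustrative shape sketch for pos_interpolate (not part of the original file):
# the leading non-grid tokens (e.g. the class token) are kept as-is and the square
# patch grid is resized bicubically. Assuming a hypothetical 1280-dim embedding:
#
#   pos = torch.randn(1, 1 + 14 * 14, 1280)           # cls token + 14x14 patch grid
#   out = pos_interpolate(pos, seq_len=1 + 16 * 16)    # -> [1, 1 + 16 * 16, 1280]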


class QuickGELU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        if quantized:
            if quant_scheme == "int8":
                linear_cls = VllmQuantLinearInt8
            elif quant_scheme == "fp8":
                linear_cls = VllmQuantLinearFp8
            elif quant_scheme == "int8-torchao":
                linear_cls = TorchaoQuantLinearInt8
            else:
                raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
        else:
            linear_cls = nn.Linear

        self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype)
        self.proj = linear_cls(dim, dim, dtype=dtype)

    def forward(self, x):
        """
        x:   [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
        x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
        x = x.reshape(b, s, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)
        return x
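
# Illustrative usage of SelfAttention (a sketch, not from the original file): the
# module is shape-preserving over [B, L, C].
#
#   attn = SelfAttention(dim=1280, num_heads=16, dtype=torch.float16)
#   y = attn(torch.randn(2, 257, 1280, dtype=torch.float16))  # -> [2, 257, 1280]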


class SwiGLU(nn.Module):
    def __init__(self, dim, mid_dim, dtype=None):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim
        # layers (dtype is accepted so the swi_glu branch in AttentionBlock can pass it)
        self.fc1 = nn.Linear(dim, mid_dim, dtype=dtype)
        self.fc2 = nn.Linear(dim, mid_dim, dtype=dtype)
        self.fc3 = nn.Linear(mid_dim, dim, dtype=dtype)

    def forward(self, x):
        x = F.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        return x
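
# Note: SwiGLU implements the gated-SiLU feed-forward: out = fc3(silu(fc1(x)) * fc2(x)).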


class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        mlp_ratio,
        num_heads,
        post_norm=False,
        causal=False,
        activation="quick_gelu",
        attn_dropout=0.0,
        proj_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
        dtype=torch.float16,
    ):
        assert activation in ["quick_gelu", "gelu", "swi_glu"]
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        if quantized:
            if quant_scheme == "int8":
                linear_cls = VllmQuantLinearInt8
            elif quant_scheme == "fp8":
                linear_cls = VllmQuantLinearFp8
            elif quant_scheme == "int8-torchao":
                linear_cls = TorchaoQuantLinearInt8
            else:
                raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
        else:
            linear_cls = nn.Linear

        self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype)
        self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
        if activation == "swi_glu":
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype)
        else:
            self.mlp = nn.Sequential(
                linear_cls(dim, int(dim * mlp_ratio), dtype=dtype),
                QuickGELU() if activation == "quick_gelu" else nn.GELU(),
                linear_cls(int(dim * mlp_ratio), dim, dtype=dtype),
                nn.Dropout(proj_dropout),
            )

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return x
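
# Note: AttentionBlock is pre-norm by default (x + attn(norm1(x)), then
# x + mlp(norm2(x))); with post_norm=True the LayerNorms are applied to the
# residual branches instead (x + norm1(attn(x)), then x + norm2(mlp(x))).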


class AttentionPool(nn.Module):
    def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim, dtype=dtype)
        self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype)
        self.proj = nn.Linear(dim, dim, dtype=dtype)
        self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), nn.Dropout(proj_dropout)
        )

    def forward(self, x):
        """
        x:  [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp
        x = x + self.mlp(self.norm(x))
        return x[:, 0]
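
# Note: AttentionPool reduces a [B, L, C] sequence to a single [B, C] vector by
# attending from a learned class query (cls_embedding) over the full sequence,
# followed by a residual MLP.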


class VisionTransformer(nn.Module):
    def __init__(
        self,
        dtype=torch.float16,
        image_size=224,
        patch_size=16,
        dim=768,
        mlp_ratio=4,
        out_dim=512,
        num_heads=12,
        num_layers=12,
        pool_type="token",
        pre_norm=True,
        post_norm=False,
        activation="quick_gelu",
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
    ):
        if image_size % patch_size != 0:
            logger.warning("image_size is not divisible by patch_size")
        assert pool_type in ("token", "token_fc", "attn_pool")
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype)
        if pool_type in ("token", "token_fc"):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype))
        self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None
        self.transformer = nn.Sequential(
            *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)]
        )
        self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)

        # head
        if pool_type == "token":
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype))
        elif pool_type == "token_fc":
            self.head = nn.Linear(dim, out_dim, dtype=dtype)
        elif pool_type == "attn_pool":
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ("token", "token_fc"):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x
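
# Note: with use_31_block=True the forward pass returns token features taken after
# all but the last transformer block (self.transformer[:-1]); neither branch applies
# the post-norm or the pooling head here.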


class XLMRobertaCLIP(nn.Module):
    def __init__(
        self,
        dtype=torch.float16,
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        vision_pre_norm=True,
        vision_post_norm=False,
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            dtype=dtype,
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps,
            quantized=quantized,
            quant_scheme=quant_scheme,
        )
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))


def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(dtype=dtype, **kwargs)

    model = model.to(device=device)

    output = (model,)
    # init transforms
    if return_transforms:
        # mean and std
        if "siglip" in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)])
        output += (transforms,)
    return output[0] if len(output) == 1 else output


def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
    )
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
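
# Illustrative usage (a sketch, not from the original file): build the ViT-H/14
# vision tower together with its preprocessing transforms, without loading weights.
#
#   model, transforms = clip_xlm_roberta_vit_h_14(
#       pretrained=False, return_transforms=True, dtype=torch.float16, device="cuda",
#       quantized=False, quant_scheme=None,
#   )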


class CLIPModel:
    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme):
        self.dtype = dtype
        self.device = device
        self.quantized = clip_quantized
        if self.quantized:
            self.checkpoint_path = clip_quantized_ckpt
        else:
            self.checkpoint_path = checkpoint_path

        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
        )
        self.model = self.model.eval().requires_grad_(False)
        weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
        keys = list(weight_dict.keys())
        for key in keys:
            if "textual" in key:
                weight_dict.pop(key)

        logger.info(f"Start Loading weights from {self.checkpoint_path}")
        self.model.load_state_dict(weight_dict)
        logger.info(f"End Loading weights from {self.checkpoint_path}")

    def visual(self, videos, args):
        if hasattr(args, "cpu_offload") and args.cpu_offload:
            self.to_cuda()
        # preprocess
        size = (self.model.image_size,) * 2
        videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast("cuda", dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)

        if hasattr(args, "cpu_offload") and args.cpu_offload:
            self.to_cpu()
        return out

    def to_cuda(self):
        self.model = self.model.cuda()

    def to_cpu(self):
        self.model = self.model.cpu()
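
# Illustrative usage of CLIPModel (a sketch; the checkpoint path below is a
# placeholder, and `args` is the runtime config object checked for cpu_offload):
#
#   clip = CLIPModel(dtype=torch.float16, device="cuda",
#                    checkpoint_path="/path/to/clip_checkpoint.pth",
#                    clip_quantized=False, clip_quantized_ckpt=None, quant_scheme=None)
#   # videos: list of [C, T, H, W] tensors with values in [-1, 1]
#   features = clip.visual(videos, args)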


class WanVideoIPHandler:
    def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
        # image_processor = CLIPImageProcessor.from_pretrained(
        #     repo_or_path, subfolder='image_processor')
        """720P-I2V-diffusers config is
            "size": {
                "shortest_edge": 224
            }
        and the 480P-I2V-diffusers config is
            "size": {
                "height": 224,
                "width": 224
            }
        but the official Wan2.1 code uses a no-crop resize by default,
        so CLIPImageProcessor is not used here.
        """
        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
        logger.info(f"Using image encoder {model_name} from {repo_or_path}")
        image_encoder.requires_grad_(require_grad)
        if mode == "eval":
            image_encoder.eval()
        else:
            image_encoder.train()
        self.dtype = dtype
        self.device = device
        self.image_encoder = image_encoder.to(device=device, dtype=dtype)
        self.size = (224, 224)
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = T.Normalize(mean=mean, std=std)
        # self.image_processor = image_processor

    def encode(
        self,
        img_tensor: Tensor,
    ):
        if img_tensor.ndim == 5:  # B C T H W
            # img_tensor = img_tensor[:, :, 0]
            img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
        img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
        img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
        img_tensor = self.normalize(img_tensor).to(self.dtype)

        image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)

        return image_embeds.hidden_states[-1]
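
# Illustrative usage of WanVideoIPHandler (a sketch; the encoder path below is a
# placeholder):
#
#   handler = WanVideoIPHandler("clip-vit-large", repo_or_path="/path/to/image_encoder",
#                               device="cuda", dtype=torch.float16)
#   # img: [B, C, H, W] (or [B, C, 1, H, W]) with values in [-1, 1]
#   hidden = handler.encode(img)  # last hidden state of the CLIP vision encoder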