# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel


__all__ = [
    "XLMRobertaCLIP",
    "clip_xlm_roberta_vit_h_14",
    "CLIPModel",
]


def pos_interpolate(pos, seq_len):
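    """Resize a [1, L, C] positional embedding to a new sequence length.

    Leading non-grid tokens (e.g. the class token) are kept as-is; the square
    grid part is bicubically interpolated to the target grid size.
    """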
    if pos.size(1) == seq_len:
        return pos
    else:
        src_grid = int(math.sqrt(pos.size(1)))
        tar_grid = int(math.sqrt(seq_len))
        n = pos.size(1) - src_grid * src_grid
        return torch.cat(
            [
                pos[:, :n],
                F.interpolate(pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), size=(tar_grid, tar_grid), mode="bicubic", align_corners=False).flatten(2).transpose(1, 2),
            ],
            dim=1,
        )


class QuickGELU(nn.Module):
    def forward(self, x):
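        # sigmoid(1.702 * x) is the fast sigmoid approximation of GELU used by the original CLIP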
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        if quantized:
            if quant_scheme == "int8":
                linear_cls = QuantLinearInt8
            elif quant_scheme == "fp8":
                linear_cls = QuantLinearFp8
            else:
                raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
        else:
            linear_cls = nn.Linear

        self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype)
        self.proj = linear_cls(dim, dim, dtype=dtype)

    def forward(self, x):
        """
        x:   [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
        x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
        x = x.reshape(b, s, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)
        return x


class SwiGLU(nn.Module):
    def __init__(self, dim, mid_dim, dtype=None):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim
        # layers
        self.fc1 = nn.Linear(dim, mid_dim, dtype=dtype)
        self.fc2 = nn.Linear(dim, mid_dim, dtype=dtype)
        self.fc3 = nn.Linear(mid_dim, dim, dtype=dtype)

    def forward(self, x):
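        # gated MLP: SiLU(fc1(x)) gates fc2(x); fc3 projects the result back to `dim`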
        x = F.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        return x


class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        mlp_ratio,
        num_heads,
        post_norm=False,
        causal=False,
        activation="quick_gelu",
        attn_dropout=0.0,
        proj_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
        dtype=torch.float16,
    ):
        assert activation in ["quick_gelu", "gelu", "swi_glu"]
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        if quantized:
            if quant_scheme == "int8":
                linear_cls = QuantLinearInt8
            elif quant_scheme == "fp8":
                linear_cls = QuantLinearFp8
            else:
                raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
        else:
            linear_cls = nn.Linear

        self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype)
        self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
        if activation == "swi_glu":
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype)
        else:
            self.mlp = nn.Sequential(
                linear_cls(dim, int(dim * mlp_ratio), dtype=dtype),
                QuickGELU() if activation == "quick_gelu" else nn.GELU(),
                linear_cls(int(dim * mlp_ratio), dim, dtype=dtype),
                nn.Dropout(proj_dropout),
            )

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return x


class AttentionPool(nn.Module):
    def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim, dtype=dtype)
        self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype)
        self.proj = nn.Linear(dim, dim, dtype=dtype)
        self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), nn.Dropout(proj_dropout)
        )

    def forward(self, x):
        """
        x:  [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp
        x = x + self.mlp(self.norm(x))
        return x[:, 0]


class VisionTransformer(nn.Module):
    def __init__(
        self,
        dtype=torch.float16,
        image_size=224,
        patch_size=16,
        dim=768,
        mlp_ratio=4,
        out_dim=512,
        num_heads=12,
        num_layers=12,
        pool_type="token",
        pre_norm=True,
        post_norm=False,
        activation="quick_gelu",
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
    ):
        if image_size % patch_size != 0:
            logger.warning("image_size is not divisible by patch_size")
        assert pool_type in ("token", "token_fc", "attn_pool")
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype)
        if pool_type in ("token", "token_fc"):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype))
        self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None
        self.transformer = nn.Sequential(
            *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)]
        )
        self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)

        # head
        if pool_type == "token":
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype))
        elif pool_type == "token_fc":
            self.head = nn.Linear(dim, out_dim, dtype=dtype)
        elif pool_type == "attn_pool":
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype)

    def forward(self, x, interpolation=False, use_31_block=False):
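        # `interpolation` resizes the positional embedding when the input grid differs from training;
        # `use_31_block` returns hidden states from all transformer blocks except the last.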
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ("token", "token_fc"):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x


class XLMRobertaCLIP(nn.Module):
    def __init__(
        self,
        dtype=torch.float16,
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        vision_pre_norm=True,
        vision_post_norm=False,
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            dtype=dtype,
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps,
            quantized=quantized,
            quant_scheme=quant_scheme,
        )
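        # learnable logit scale, initialized to log(1 / 0.07) as in the original CLIP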
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))


def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(dtype=dtype, **kwargs)

    model = model.to(device=device)

    output = (model,)
    # init transforms
    if return_transforms:
        # mean and std
        if "siglip" in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)])
        output += (transforms,)
    return output[0] if len(output) == 1 else output


def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
    )
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)


class CLIPModel:
    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme):
        self.dtype = dtype
        self.device = device
        self.quantized = clip_quantized
        if self.quantized:
            self.checkpoint_path = clip_quantized_ckpt
        else:
            self.checkpoint_path = checkpoint_path

        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
        )
        self.model = self.model.eval().requires_grad_(False)
        weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
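        # drop the text-branch weights: this wrapper only builds and loads the visual encoder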
        keys = list(weight_dict.keys())
        for key in keys:
            if "textual" in key:
                weight_dict.pop(key)

        logger.info(f"Start Loading weights from {self.checkpoint_path}")
        self.model.load_state_dict(weight_dict)
        logger.info(f"End Loading weights from {self.checkpoint_path}")

    def visual(self, videos, args):
        if args.cpu_offload:
            self.to_cuda()
        # preprocess
        size = (self.model.image_size,) * 2
        videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
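        # inputs are in [-1, 1]; map to [0, 1] and apply the CLIP mean/std normalization (the last transform)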
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast("cuda", dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
        if args.cpu_offload:
            self.to_cpu()
        return out

    def to_cuda(self):
        self.model = self.model.cuda()

    def to_cpu(self):
        self.model = self.model.cpu()


class WanVideoIPHandler:
    def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
        # image_processor = CLIPImageProcessor.from_pretrained(
        #     repo_or_path, subfolder='image_processor')
        """720P-I2V-diffusers config is
wangshankun's avatar
wangshankun committed
464
465
466
467
468
469
470
471
472
473
            "size": {
                "shortest_edge": 224
            }
        and 480P-I2V-diffusers config is
          "size": {
            "height": 224,
            "width": 224
        }
        but Wan2.1 official use no_crop resize by default
        so I don't use CLIPImageProcessor
wangshankun's avatar
wangshankun committed
474
        """
        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
        logger.info(f"Using image encoder {model_name} from {repo_or_path}")
        image_encoder.requires_grad_(require_grad)
        if mode == "eval":
            image_encoder.eval()
        else:
            image_encoder.train()
        self.dtype = dtype
        self.device = device
        self.image_encoder = image_encoder.to(device=device, dtype=dtype)
        self.size = (224, 224)
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = T.Normalize(mean=mean, std=std)
        # self.image_processor = image_processor

    def encode(
        self,
        img_tensor: Tensor,
    ):
        if img_tensor.ndim == 5:  # B C T H W
            # img_tensor = img_tensor[:, :, 0]
            img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
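        # map from [-1, 1] to [0, 1], resize to 224x224, then apply CLIP mean/std normalization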
        img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
        img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
        img_tensor = self.normalize(img_tensor).to(self.dtype)

        image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)

        return image_embeds.hidden_states[-1]
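

# Minimal usage sketch (illustration only, not part of the original pipeline).
# The checkpoint path, resolution, and `args` namespace below are assumptions.
#
#   from types import SimpleNamespace
#
#   clip = CLIPModel(
#       dtype=torch.float16,
#       device="cuda",
#       checkpoint_path="/path/to/clip_vit_h_14.pth",  # hypothetical checkpoint path
#       clip_quantized=False,
#       clip_quantized_ckpt=None,
#       quant_scheme=None,
#   )
#   # one clip per list entry, shaped [C, T, H, W] with values in [-1, 1]
#   videos = [torch.randn(3, 1, 480, 832, device="cuda", dtype=torch.float16)]
#   features = clip.visual(videos, SimpleNamespace(cpu_offload=False))  # penultimate-block CLIP features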