# model.py
import torch
from transformers import CLIPTextModel, AutoTokenizer


# Text encoder built on the Hugging Face CLIP text model.
class TextEncoderHFClipModel:
    """Wrap a Hugging Face ``CLIPTextModel`` and its tokenizer for prompt encoding.

    The weights are loaded in float16 and placed on ``device`` when the
    object is constructed. ``to_cpu`` / ``to_cuda`` move the model between
    devices, and ``infer`` uses them when ``args.cpu_offload`` is set.
    """

    def __init__(self, model_path, device):
        """Store the configuration, then load the model and tokenizer.

        Args:
            model_path: Path or identifier handed to ``from_pretrained``.
            device: Device the model is moved to right after loading.
        """
        self.device = device
        self.model_path = model_path
        self.init()
        self.load()

    def init(self):
        """Set tokenization settings used by ``infer``."""
        # Every prompt is padded or truncated to exactly this many tokens.
        self.max_length = 77

    def load(self):
        """Load the text model (float16, on ``self.device``) and the tokenizer."""
        self.model = CLIPTextModel.from_pretrained(self.model_path).to(torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")

    def to_cpu(self):
        """Move the model to the CPU."""
        self.model = self.model.to("cpu")

    def to_cuda(self):
        """Move the model to a CUDA device.

        The configured ``self.device`` is used when it is a CUDA device, so
        an explicit index such as ``cuda:1`` is honoured; otherwise the
        default ``"cuda"`` device is used.
        """
        target = torch.device(self.device)
        if target.type != "cuda":
            target = torch.device("cuda")
        self.model = self.model.to(target)

    @torch.no_grad()
    def infer(self, text, args):
        """Encode ``text`` and return the model's pooled output.

        Args:
            text: Prompt (or prompts) passed straight to the tokenizer.
            args: Object exposing a boolean ``cpu_offload`` attribute. When
                true, the model is moved to CUDA for the call and back to
                the CPU afterwards.

        Returns:
            A ``(pooler_output, attention_mask)`` tuple. Only the model is
            moved back to the CPU when offloading; the returned tensors are
            left where they were produced.
        """
        offload = args.cpu_offload
        if offload:
            self.to_cuda()
        try:
            # Send the inputs to wherever the model currently lives instead
            # of a hard-coded "cuda", so inputs and weights always agree.
            tokens = self.tokenizer(
                text,
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=True,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            ).to(self.model.device)

            outputs = self.model(
                input_ids=tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=False,
            )
            pooled_output = outputs["pooler_output"]
        finally:
            # Release the GPU even if tokenization or the forward pass fails.
            if offload:
                self.to_cpu()
        return pooled_output, tokens["attention_mask"]


if __name__ == "__main__":
    # Manual smoke test: encode a single prompt and print the result.
    from types import SimpleNamespace

    # NOTE: machine-specific checkpoint path; adjust before running elsewhere.
    model = TextEncoderHFClipModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder_2", torch.device("cuda"))
    text = "A cat walks on the grass, realistic style."
    # infer() requires an args object; it only reads ``cpu_offload``.
    outputs = model.infer(text, SimpleNamespace(cpu_offload=False))
    print(outputs)