# model.py
import torch
from transformers import CLIPTextModel, AutoTokenizer


class TextEncoderHFClipModel:
    """Thin wrapper around a Hugging Face ``CLIPTextModel`` and its tokenizer.

    The model is loaded in float16 onto ``device`` at construction time.
    ``infer`` tokenizes a prompt (padded/truncated to ``max_length``) and
    returns the model's ``pooler_output`` together with the attention mask.
    """

    def __init__(self, model_path, device):
        """Load the text encoder and tokenizer.

        Args:
            model_path: Path or hub id passed to ``from_pretrained`` for both
                the model and the tokenizer.
            device: Device the model is moved to after loading.
        """
        self.device = device
        self.model_path = model_path
        self.init()
        self.load()

    def init(self):
        """Set the tokenization settings used by :meth:`infer`."""
        # Fixed token length for padding/truncation; presumably the CLIP text
        # context length -- TODO confirm against the checkpoint's config.
        self.max_length = 77

    def load(self):
        """Load the model (cast to float16, moved to ``self.device``) and tokenizer."""
        self.model = CLIPTextModel.from_pretrained(self.model_path).to(torch.float16).to(self.device)
        # Right padding: pad tokens are appended after the real tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")

    def to_cpu(self):
        """Move the model weights to ``"cpu"``."""
        self.model = self.model.to("cpu")

    def to_cuda(self):
        """Move the model weights to ``"cuda"``."""
        self.model = self.model.to("cuda")

    @torch.no_grad()
    def infer(self, text, config=None):
        """Encode ``text`` with the CLIP text model.

        Args:
            text: Prompt passed straight to the tokenizer.
            config: Optional object with a boolean ``cpu_offload`` attribute.
                When truthy, the model is moved to CUDA for the forward pass
                and moved back to CPU afterwards, even if the forward pass
                raises. ``None`` (the default) disables offloading.

        Returns:
            Tuple ``(pooled_output, attention_mask)``: the model's
            ``pooler_output`` and the attention mask produced by the tokenizer.
        """
        cpu_offload = config is not None and config.cpu_offload
        if cpu_offload:
            self.to_cuda()
        try:
            tokens = self.tokenizer(
                text,
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=True,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            ).to(self.model.device)  # follow the model wherever it currently lives

            outputs = self.model(
                input_ids=tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=False,
            )
            pooled_output = outputs["pooler_output"]
        finally:
            # Always return offloaded weights to CPU, even on failure.
            if cpu_offload:
                self.to_cpu()
        return pooled_output, tokens["attention_mask"]


if __name__ == "__main__":
    import sys
    from types import SimpleNamespace

    # Checkpoint directory/hub id may be passed as the first CLI argument.
    model_path = sys.argv[1] if len(sys.argv) > 1 else ""
    model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
    text = "A cat walks on the grass, realistic style."
    # infer() reads config.cpu_offload, so pass an explicit no-offload config.
    outputs = model.infer(text, SimpleNamespace(cpu_offload=False))
    print(outputs)