model.py 1.65 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
root's avatar
root committed
2
from loguru import logger
PengGao's avatar
PengGao committed
3
from transformers import AutoTokenizer, CLIPTextModel
helloyongyang's avatar
helloyongyang committed
4
5


Dongz's avatar
Dongz committed
6
class TextEncoderHFClipModel:
    """Wrapper around a Hugging Face ``CLIPTextModel`` that returns pooled text embeddings.

    The model is loaded in ``torch.float16`` and moved to ``device`` at construction.
    When ``config.cpu_offload`` is set, :meth:`infer` moves the model to CUDA for the
    forward pass and back to CPU afterwards.
    """

    def __init__(self, model_path, device):
        """Load the tokenizer and text model.

        Args:
            model_path: Path or hub id passed to ``from_pretrained`` for both the
                model and the tokenizer.
            device: Device the model is moved to after loading
                (e.g. ``torch.device("cuda")``).
        """
        self.device = device
        self.model_path = model_path
        self.init()
        self.load()

    def init(self):
        """Set the tokenization length used by :meth:`infer`."""
        # Prompts are truncated / right-padded to exactly this many tokens in infer().
        self.max_length = 77

    def load(self):
        """Load the fp16 text model onto ``self.device`` and a right-padding tokenizer."""
        self.model = CLIPTextModel.from_pretrained(self.model_path).to(torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")

    def to_cpu(self):
        """Move the text model to CPU (used to free GPU memory when offloading)."""
        self.model = self.model.to("cpu")

    def to_cuda(self):
        """Move the text model to the default CUDA device."""
        self.model = self.model.to("cuda")

    @torch.no_grad()
    def infer(self, text, config=None):
        """Encode ``text`` with the CLIP text model.

        Args:
            text: Prompt text, passed straight to the tokenizer.
            config: Object with a ``cpu_offload`` attribute. When truthy, the model
                is moved to CUDA for this call and back to CPU afterwards, even if
                the forward pass raises. ``None`` (the default) disables offloading.

        Returns:
            Tuple ``(pooled_output, attention_mask)``. The first element is the
            model's ``pooler_output``; the second is the tokenizer's attention mask.
            Both live on the device the model occupied during the forward pass.
        """
        cpu_offload = config is not None and bool(config.cpu_offload)
        if cpu_offload:
            self.to_cuda()
        try:
            tokens = self.tokenizer(
                text,
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=True,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            ).to(self.model.device)  # follow the model rather than hard-coding "cuda"

            outputs = self.model(
                input_ids=tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=False,
            )
            pooled_output = outputs["pooler_output"]
        finally:
            # Restore the offloaded state even when tokenization/forward fails.
            if cpu_offload:
                self.to_cpu()
        return pooled_output, tokens["attention_mask"]


if __name__ == "__main__":
    # Manual smoke test: encode a single prompt and log the result.
    import sys
    from types import SimpleNamespace

    # Model directory / hub id may be given as the first CLI argument.
    model_path = sys.argv[1] if len(sys.argv) > 1 else ""
    model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
    text = "A cat walks on the grass, realistic style."
    # infer() reads `config.cpu_offload`, so pass an explicit config object.
    outputs = model.infer(text, SimpleNamespace(cpu_offload=False))
    logger.info(outputs)