import os
import sys
from collections import namedtuple

import oneflow as flow
from oneflow import nn

from .models import l2norm


def import_flow_clip(fn):
    """Run fn with the sibling CLIP directory temporarily on sys.path."""

    def wrapper(*args, **kwargs):
        sys.path.append(
            os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")), "CLIP")
        )
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore sys.path even if fn raises.
            sys.path.pop()

    return wrapper


EmbeddedText = namedtuple("EmbeddedText", ["text_embed", "text_encodings"])
EmbeddedImage = namedtuple("EmbeddedImage", ["image_embed", "image_encodings"])


class BaseClipAdapter(nn.Module):
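    """Uniform text/image embedding interface over a wrapped CLIP model."""
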
    def __init__(self, clip, **kwargs):
        super().__init__()
        self.clip = clip
        self.overrides = kwargs

    @property
    def dim_latent(self):
        raise NotImplementedError

    @property
    def image_size(self):
        raise NotImplementedError

    @property
    def image_channels(self):
        raise NotImplementedError

    @property
    def max_text_len(self):
        raise NotImplementedError

    def embed_text(self, text):
        raise NotImplementedError

    def embed_image(self, image):
        raise NotImplementedError
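
    def validate_and_resize_image(self, image):
        # Assumed helper: embed_image in the subclass calls this, but no
        # definition was present in the file. A minimal sketch: reject images
        # smaller than CLIP's input resolution, then resize to match it.
        image_size = image.shape[-1]
        assert (
            image_size >= self.image_size
        ), f"image size {image_size} is smaller than the required {self.image_size}"
        return nn.functional.interpolate(
            image, self.image_size, mode="bilinear", align_corners=False
        )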


class OpenAIClipAdapter(BaseClipAdapter):
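    """Adapter over the OneFlow port of OpenAI CLIP, loaded from the sibling CLIP directory."""
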
    @import_flow_clip
    def __init__(self, name="ViT-L/14"):
        import clip

        openai_clip, preprocess = clip.load(name)
        super().__init__(openai_clip)
        self.eos_id = 49407  # for handling 0 being also '!'

        text_attention_final = self.find_layer("ln_final")
        self.handle = text_attention_final.register_forward_hook(self._hook)
        self.clip_normalize = preprocess.transforms[-1]
        self.cleared = False

    def find_layer(self, layer):
        modules = dict(self.clip.named_modules())
        return modules.get(layer)

    def clear(self):
        if self.cleared:
            return

        # register_forward_hook returns a handle that is removed, not called.
        self.handle.remove()
        self.cleared = True

    def _hook(self, _, inputs, outputs):
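        # Forward hook on ln_final: capture per-token text encodings for embed_text.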
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        return 512

    @property
    def image_size(self):
        return self.clip.visual.input_resolution

    @property
    def image_channels(self):
        return 3

    @property
    def max_text_len(self):
        return self.clip.context_length

    @flow.no_grad()
    def embed_text(self, text):
        text = text[..., : self.max_text_len]

        assert not self.cleared
        # Non-padding mask; the pad id is 0 (which also serves as '!' in the
        # CLIP vocab, hence eos_id is tracked separately in __init__).
        text_mask = text != 0

        text_embed = self.clip.encode_text(text)
        # Per-token encodings captured by the ln_final forward hook.
        text_encodings = self.text_encodings
        del self.text_encodings
        return l2norm(text_embed.float()), text_encodings.float(), text_mask

    @flow.no_grad()
    def embed_image(self, image):
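        # Validate/resize, apply CLIP's normalization transform, then encode.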
        assert not self.cleared
        image = self.validate_and_resize_image(image)
        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return EmbeddedImage(l2norm(image_embed.float()), None)
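

# Usage sketch (assumes the CLIP port on sys.path exposes the usual clip.load
# interface; `tokens` and `images` are hypothetical inputs of shape
# (batch, context_length) and (batch, 3, H, W)):
#
#   adapter = OpenAIClipAdapter("ViT-L/14")
#   text_embed, text_encodings, text_mask = adapter.embed_text(tokens)
#   image_embed, _ = adapter.embed_image(images)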