models_lit.py 7.59 KB
Newer Older
suily's avatar
suily committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Models from Locked-image text Tuning.

See paper https://arxiv.org/abs/2111.07991
"""

import dataclasses
import os
from typing import Optional, Tuple

import flax.linen as nn
import jax.numpy as jnp
import ml_collections
from vit_jax import checkpoint
from vit_jax import models_vit
from vit_jax import preprocess

from flaxformer.architectures.bert import bert
from flaxformer.architectures.bert import configs


# Public Cloud Storage prefix for LiT artifacts. `LitModel` appends
# `{model_name}.npz` (checkpoint) or `{model_name}.{txt|model}` (vocabulary).
BASE_PATH = 'gs://vit_models/lit'


class BertModel(nn.Module):
  """BERT text encoder that returns the CLS-token embedding.

  Wraps flaxformer's `bert.BertEncoder`, configured from the BERT-Base or
  BERT-Large preset. When `num_classes` is truthy, a linear `head` projection
  is applied on top of the CLS embedding.

  Attributes:
    config: Preset selector; one of `'base'` or `'large'`.
    num_classes: Output size of the optional linear head. If `None` (or 0),
      no head is added and the CLS embedding itself is returned.
  """

  config: str
  num_classes: Optional[int] = None

  @nn.compact
  def __call__(self, tokens):
    """Encodes a batch of token ids.

    Args:
      tokens: Integer token ids of shape (batch, max_len). Zero-valued ids are
        treated as padding when building the attention mask.

    Returns:
      A tuple `(x, out)`: `x` is the head output if a head is configured,
      otherwise the CLS embedding; `out` maps 'transformed' (full encoder
      output), 'pre_logits' (CLS embedding) and, with a head, 'logits'.
    """
    presets = {
        'base': configs.BertBaseConfig(),
        'large': configs.BertLargeConfig(),
    }
    encoder = bert.BertEncoder(**dataclasses.asdict(presets[self.config]))

    n_batch, seq_len = tokens.shape
    positions = jnp.tile(jnp.arange(0, seq_len, dtype=jnp.int32), [n_batch, 1])
    segments = jnp.zeros([n_batch, seq_len], dtype=jnp.int32)
    # Any non-zero token id is a real token; zeros are padding.
    mask = tokens.astype(jnp.bool_).astype(jnp.int32)

    hidden = encoder(
        token_ids=tokens,
        position_ids=positions,
        segment_ids=segments,
        input_mask=mask,
        enable_dropout=False,
    )
    cls_embedding = hidden[:, 0]  # CLS token
    out = {'transformed': hidden, 'pre_logits': cls_embedding}

    if not self.num_classes:
      return cls_embedding, out

    logits = nn.Dense(self.num_classes, name='head')(cls_embedding)
    out['logits'] = logits
    return logits, out


class TextTransformer(nn.Module):
  """Small text transformer built on the ViT encoder stack.

  Token ids are embedded, a learned positional embedding is added, the
  sequence runs through `models_vit.Encoder`, and the last sequence position
  is projected by a linear `head`.

  Attributes:
    num_classes: Output size of the final linear head.
    width: Size of the token embeddings.
    num_layers: Number of encoder blocks.
    mlp_dim: Hidden size of the MLP inside each encoder block.
    num_heads: Number of attention heads.
    dropout_rate: Dropout rate handed to the encoder (note the encoder is
      always invoked with `train=False` here).
    vocab_size: Number of rows in the token-embedding table.
  """

  num_classes: int
  width: int = 512
  num_layers: int = 12
  mlp_dim: int = 2048
  num_heads: int = 8
  dropout_rate: float = 0.0
  vocab_size: int = 32_000

  @nn.compact
  def __call__(self, x):
    """Encodes a batch of token ids.

    Args:
      x: Integer token ids; after embedding the result is unpacked as
        (batch, seq_len, width).

    Returns:
      A tuple `(logits, out)` where `out` maps 'embedded' (token embeddings
      before the positional embedding), 'pre_logits' (last-position encoder
      output) and 'logits'.
    """
    token_embeddings = nn.Embed(
        num_embeddings=self.vocab_size, features=self.width)(x)

    _, seq_len, dim = token_embeddings.shape
    pos_embedding = self.param(
        'pos_embedding',
        nn.initializers.normal(stddev=1 / jnp.sqrt(dim)),
        (1, seq_len, dim),
        token_embeddings.dtype,
    )

    # Positional embeddings are added above, so the encoder must not add its own.
    encoder = models_vit.Encoder(
        num_layers=self.num_layers,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=0,
        add_position_embedding=False,
    )
    encoded = encoder(token_embeddings + pos_embedding, train=False)

    # Unlike BertModel (first/CLS position), this pools the *last* position.
    last_token = encoded[:, -1, :]
    logits = nn.Dense(self.num_classes, name='head')(last_token)

    return logits, {
        'embedded': token_embeddings,
        'pre_logits': last_token,
        'logits': logits,
    }


class LitModel(nn.Module):
  """Locked-image text Tuning model.

  See paper https://arxiv.org/abs/2111.07991

  For examples, refer to Colab

  https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb

  Attributes:
    image: Configuration for ViT image tower; its entries are passed as keyword
      arguments to `models_vit.VisionTransformer`.
    text_model: Which text tower to build: `'bert'` (`BertModel`) or
      `'text_transformer'` (`TextTransformer`).
    text: Configuration for text tower; its entries are passed as keyword
      arguments to the selected text model (may be empty or `None`).
    pp: Preprocessing configuration; `tokenizer_name`, `max_len` and `size`
      are read from it.
    out_dim: Size of optional image/text heads that are added to the towers,
      given as `(image_dim, text_dim)`. A single int is also accepted and is
      then used for both towers.
    model_name: Refers to the key in `model_configs.MODEL_CONFIGS`; also used
      to build the checkpoint and vocabulary file names.
  """

  image: ml_collections.ConfigDict
  text_model: str
  text: ml_collections.ConfigDict
  pp: ml_collections.ConfigDict
  out_dim: Tuple[Optional[int], Optional[int]]
  model_name: str

  def load_variables(self, path=None, cache=True):
    """Loads variables.

    Args:
      path: Path to load params from. If not specified, then the params will be
        loaded from the default public Cloud storage path, unless they exist in
        the current working directory.
      cache: If set to `True` and `path` is not specified (the default), then
        the files will be copied from Cloud and stored in the current working
        directory, and then loaded from that local copy.

    Returns:
      The module variables, to be used with `model.apply()`
    """
    if path is None:
      local_path = f'{self.model_name}.npz'
      if os.path.exists(local_path):
        # Only warn when an already-existing local file is picked up instead
        # of the cloud checkpoint (previously this also fired right after a
        # fresh download, which was misleading).
        print('\n⚠️ Reusing local copy:', local_path)
        path = local_path
      else:
        path = f'{BASE_PATH}/{self.model_name}.npz'
        print('Loading params from cloud:', path)
        if cache:
          checkpoint.copy(path, local_path)
          # Load from the freshly cached copy if it materialized; otherwise
          # fall back to reading straight from the cloud path.
          if os.path.exists(local_path):
            path = local_path
    return {'params': checkpoint.load(path)}

  @property
  def vocab_path(self):
    """Cloud path of the tokenizer vocabulary matching `pp.tokenizer_name`."""
    ext = {
        'bert': 'txt',
        'sentencepiece': 'model',
    }[self.pp.tokenizer_name]
    return f'{BASE_PATH}/{self.model_name}.{ext}'

  def get_pp(self, crop=False):
    """Returns a preprocessing function suitable for `tf.data.Dataset.map()`."""
    return preprocess.get_pp(
        tokenizer_name=self.pp.tokenizer_name,
        vocab_path=self.vocab_path,
        max_len=self.pp.max_len,
        size=self.pp.size,
        crop=crop)

  def get_tokenizer(self):
    """Returns a tokenizer."""
    return preprocess.get_tokenizer(self.pp.tokenizer_name)(
        vocab_path=self.vocab_path,
        max_len=self.pp.max_len)

  def get_image_preprocessing(self, crop=False):
    """Returns a function to pre-process images (resize, value range)."""
    return preprocess.PreprocessImages(size=self.pp.size, crop=crop)

  @nn.compact
  def __call__(self, *, images=None, tokens=None):
    """Embeds images and/or tokens.

    Args:
      images: Batch of images, prepared with the function returned by
        `get_image_preprocessing()` or `get_pp()`.
      tokens: Batch of tokens, prepared with the function returned by
        `get_tokenizer()` or `get_pp()`.

    Returns:
      A tuple `(zimg, ztxt, out)`, where `zimg` is a batch of embeddings for
      the images (or `None`, if images were not specified), `ztxt` is a batch
      of embeddings for the tokens (or `None`, if tokens were not specified),
      and `out` is a dictionary of additional values, such as `out['t']` that
      is the temperature multiplied with the vector dot products before the
      softmax is applied. Returned embeddings are L2-normalized along axis 1.
    """

    # Support calling without text or without images, for example for few-shot.
    ztxt, zimg = None, None
    out = {}
    out_dims = self.out_dim
    if isinstance(out_dims, int):
      # A single int configures both heads identically.
      out_dims = (out_dims, out_dims)

    if tokens is not None:
      # Embed the text:
      model_class = {
          'bert': BertModel,
          'text_transformer': TextTransformer,
      }[self.text_model]
      # Entries in `self.text` take precedence over the default `num_classes`.
      text_model = model_class(
          **{
              'num_classes': out_dims[1],
              **(self.text or {})
          }, name='txt')

      ztxt, out_txt = text_model(tokens)
      for k, v in out_txt.items():
        out[f'txt/{k}'] = v

      # Normalize the embeddings the models give us; the epsilon guards
      # against division by zero for all-zero embeddings.
      out['txt/norm'] = jnp.linalg.norm(ztxt, axis=1, keepdims=True)
      out['txt/normalized'] = ztxt = ztxt / (out['txt/norm'] + 1e-8)

    if images is not None:
      # Here `num_classes` overrides any value present in `self.image`.
      image_model = models_vit.VisionTransformer(
          **{
              **self.image,
              'num_classes': out_dims[0],
          }, name='img')  # pylint: disable=not-a-mapping
      zimg = image_model(images, train=False)

      # Normalize the embeddings the models give us.
      out['img/norm'] = jnp.linalg.norm(zimg, axis=1, keepdims=True)
      out['img/normalized'] = zimg = zimg / (out['img/norm'] + 1e-8)

    # Learned log-temperature, initialized to 0 so that `out['t']` starts at 1.
    t = self.param('t', nn.initializers.zeros, (1,), jnp.float32)
    out['t'] = jnp.exp(t)

    return zimg, ztxt, out