# vit.py
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register ViT models
# ===============================

# Small ViT configuration shared by every model registered in this file
# (4 layers, hidden size 128). Fields not passed here, such as image_size,
# patch_size and num_channels, keep the transformers.ViTConfig defaults.
config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)


# define data gen function
def data_gen():
    """Build a single random image batch for the ViT models.

    The tensor shape is read from the shared ``config`` instead of being
    hard-coded. The input therefore stays consistent with the patch count that
    ``data_gen_for_masked_image_modeling`` derives from the same config. With
    the ViTConfig defaults (3 channels, 224x224 images) the shape is
    ``(1, 3, 224, 224)``, the same as before.

    Returns:
        dict: ``{"pixel_values": Tensor}`` of shape
        ``(1, config.num_channels, config.image_size, config.image_size)``,
        filled with standard-normal noise.
    """
    pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
    return dict(pixel_values=pixel_values)


def data_gen_for_image_classification():
    """Build inputs for ``ViTForImageClassification``.

    Takes the image batch from ``data_gen`` and adds a ``labels`` entry.

    Returns:
        dict: ``pixel_values`` from ``data_gen`` plus ``labels``, a 1-element
        tensor holding class index 0.
    """
    data = data_gen()
    # data_gen produces a batch of one image, so a single label (class 0) is supplied.
    data["labels"] = torch.tensor([0])
    return data


def data_gen_for_masked_image_modeling():
    """Build inputs for ``ViTForMaskedImageModeling``.

    Takes the image batch from ``data_gen`` and adds ``bool_masked_pos``, a
    random boolean mask with one entry per image patch.

    Returns:
        dict: ``pixel_values`` from ``data_gen`` plus ``bool_masked_pos`` of
        shape ``(1, num_patches)``, where
        ``num_patches = (config.image_size // config.patch_size) ** 2``.
    """
    data = data_gen()
    # Patches per side, squared, gives the total number of patches.
    num_patches = (config.image_size // config.patch_size) ** 2
    # Each patch is independently masked with probability 1/2 (randint over {0, 1}).
    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
    data["bool_masked_pos"] = bool_masked_pos
    return data


# Output transform: identity, so model outputs are handed on unchanged.
def output_transform_fn(x):
    return x


# Loss functions: each one reduces a field of the model output to the value
# passed to model_zoo.register as ``loss_fn``.
def loss_fn_for_vit_model(x):
    # ViTModel: mean over the pooled output.
    return x.pooler_output.mean()


def loss_fn_for_image_classification(x):
    # ViTForImageClassification: mean over the classification logits.
    return x.logits.mean()


def loss_fn_for_masked_image_modeling(x):
    # ViTForMaskedImageModeling: the output's own ``loss`` field.
    return x.loss

# register the following models
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
# Bare ViT encoder; its loss is the mean of ``pooler_output``.
model_zoo.register(
    name="transformers_vit",
    model_fn=lambda: transformers.ViTModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_vit_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# ViT with a masked-image-modeling head: fed a random patch mask and scored
# by the ``loss`` field of its output.
model_zoo.register(
    name="transformers_vit_for_masked_image_modeling",
    model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
    data_gen_fn=data_gen_for_masked_image_modeling,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_masked_image_modeling,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# ViT with a classification head: fed images plus labels and scored by the
# mean of its logits.
model_zoo.register(
    name="transformers_vit_for_image_classification",
    model_fn=lambda: transformers.ViTForImageClassification(config),
    data_gen_fn=data_gen_for_image_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_image_classification,
    model_attribute=ModelAttribute(has_control_flow=True),
)