blip2.py 2.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register BLIP-2 models
# ===============================


# define data gen function
def data_gen():
    """Build a tiny BLIP-2 input batch: one random image plus a tokenized prompt.

    The token ids were captured by running the ``Salesforce/blip2-opt-2.7b``
    processor on a real image and prompt::

        from PIL import Image
        import requests
        from transformers import Blip2Processor

        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        prompt = "Question: how many cats are there? Answer:"
        inputs = processor(images=image, text=prompt, return_tensors="pt")

    The real image is replaced by uniform random pixels in [0, 1), so nothing
    has to be downloaded.

    Returns:
        dict with ``pixel_values`` (float32, shape 1x3x224x224),
        ``input_ids`` and ``attention_mask`` (int64, shape 1x11) and
        ``labels`` (int64, shape 1x2).
    """
    prompt_ids = [2, 45641, 35, 141, 171, 10017, 32, 89, 116, 31652, 35]
    input_ids = torch.tensor([prompt_ids], dtype=torch.int64)
    return {
        'pixel_values': torch.rand(1, 3, 224, 224, dtype=torch.float32),
        'input_ids': input_ids,
        # the mask is all ones: every prompt position is attended, none is padding
        'attention_mask': torch.ones_like(input_ids),
        'labels': torch.tensor([[34, 56]], dtype=torch.int64),
    }


def output_transform_fn(x):
    """Return the model output unchanged (identity output transform)."""
    return x


def loss_fn_blip2_model(x):
    """Return the ``loss`` attribute of a BLIP-2 model output."""
    return x.loss

def _tiny_blip2_config():
    """Return a default ``Blip2Config`` shrunk to one hidden layer per sub-model.

    Also zeroes the Q-Former attention/hidden dropout probabilities and the
    text model's ``dropout`` (presumably so test outputs are reproducible).
    """
    cfg = transformers.Blip2Config()
    # a single layer in each of the text, Q-Former and vision sub-models keeps the test model small
    for sub_cfg in (cfg.text_config, cfg.qformer_config, cfg.vision_config):
        sub_cfg.num_hidden_layers = 1
    qformer_cfg = cfg.qformer_config
    qformer_cfg.attention_probs_dropout_prob = 0
    qformer_cfg.hidden_dropout_prob = 0
    cfg.text_config.dropout = 0
    return cfg


config = _tiny_blip2_config()

def _register_blip2(name, model_fn):
    """Register one BLIP-2 variant with the shared data, output and loss hooks."""
    model_zoo.register(name=name,
                       model_fn=model_fn,
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_blip2_model,
                       model_attribute=ModelAttribute(has_control_flow=True))


# register the blip2 variants; models are built lazily from the shared `config`
_register_blip2('transformers_blip2', lambda: transformers.Blip2Model(config))
# NOTE: "gerneration" is a misspelling, but the registry key is kept unchanged
# because other code may look this model up by name.
_register_blip2('transformers_blip2_conditional_gerneration',
                lambda: transformers.Blip2ForConditionalGeneration(config))