run_opt.py 4.67 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria

from PIL import Image

import os
import requests
from PIL import Image
from io import BytesIO

from transformers import TextStreamer


from vary.model.plug.blip_process import BlipImageEvalProcessor

from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'

# 这段代码是用于加载和评估一个预训练模型的。它使用了Hugging Face的Transformers库,以及PIL库来处理图像。

def load_image(image_file):
    """
    加载图像。
    
    参数:
    - image_file: 图像的文件路径或URL。
    
    返回值:
    - image: 打开的图像,转换为RGB模式。
    """
    # 判断图像文件是否是URL
    if image_file.startswith('http') or image_file.startswith('https'):
        # 从URL下载图像
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        # 从本地文件系统打开图像
        image = Image.open(image_file).convert('RGB')
    return image


def eval_model(args):
    """
    评估模型性能。

    参数:
    - args: 一个包含模型名称和图像文件路径等信息的参数对象。

    说明:
    此函数加载指定的模型和对应的tokenizer,将输入的图像转换为文本描述。
    使用关键词停止准则以防止生成不必要的文本。
    """
    # Model
    # 初始化模型和tokenizer
    disable_torch_init() # 禁用PyTorch的默认初始化
    model_name = os.path.expanduser(args.model_name) # 解析模型名称

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # 解析模型名称

    model = varyOPTForCausalLM.from_pretrained(model_name) # 从预训练加载模型



    # 将模型转移到GPU并设置数据类型
    model.to(device='cuda',  dtype=torch.bfloat16) 

    # 设置图像处理和token长度
    image_processor_high =  test_transform    # 定义图像预处理
    image_token_len = 256 # 图像token长度

    # 构建输入prompt
    prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
    inputs = tokenizer([prompt]) # 对prompt进行tokenization

    # 加载输入图像
    image = load_image(args.image_file)
    image_1 = image.copy()
    # 加载输入图像
    image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)

    # 准备模型输入
    input_ids = torch.as_tensor(inputs.input_ids).cuda()
    # 设置停止准则
    stop_str = '</s>'
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    # 设置文本流
    """
    这段代码首先创建了一个TextStreamer实例,用于处理模型生成的文本流。然后,它在CUDA设备上以mixed precision模式运行模型的generate方法,生成基于输入图像的新tokens。
    这个过程包括输入tokens、图像张量、随机采样设置、束搜索参数、文本流处理器、最大新生成tokens数以及停止准则。
    """
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # 使用mixed precision进行模型推理
    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
            do_sample=True,
            num_beams = 1,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
            )
        



        # input_token_len = input_ids.shape[1]
        # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        # if outputs.endswith(stop_str):
        #     outputs = outputs[:-len(stop_str)]
        # outputs = outputs.strip()

        # print(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--image-file", type=str, required=True)
    # parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default=None)
    args = parser.parse_args()

    eval_model(args)