import torch
import sys
import os.path as osp
import warnings
from .base import BaseModel

from PIL import Image
from huggingface_hub import snapshot_download
from vlmeval.smp import *

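# Per-model configuration. `ckpt` holds a HuggingFace repo id that
# VXVERSE.__init__ resolves to a local checkpoint path; the remaining keys
# mirror the vxverse training/config schema (LoRA, ViT and LLM settings).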
model_cfgs = {
    'XVERSE-V-13B': {
        'arch': 'vxverse',
        'model_type': 'pretrain_xverse13b-chat',
        'max_txt_len': 512,
        'end_sym': '<|endoftext|>',
        'low_resource': False,
        'prompt_template': 'Human: {}\nAssistant: ',
        'ckpt': 'xverse/XVERSE-V-13B',
        'lora_r': 128,
        'lora_alpha': 256,
        'lora_dropout': 0.05,
        'lora_target_modules': 'all_linear',
        'has_qformer': False,
        'n_proj_layers': 2,
        'vit_model': 'openai/clip-vit-large-patch14',
        'vit_path': 'openai/clip-vit-large-patch14',
        'image_size': 224,
        'drop_path_rate': 0,
        'vit_precision': 'fp16',
        'llama_model': 'xverse/XVERSE-13B-Chat',
    }
}


class VXVERSE(BaseModel):
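    """Wrapper for the XVERSE-V-13B vision-language model.

    Instantiates the model and image processor through the vxverse registry
    (imported from `root`) and exposes the `generate_inner` interface of
    `BaseModel` for single-image prompts.
    """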

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_name='XVERSE-V-13B', root=None, **kwargs):
        from omegaconf import OmegaConf
        if root is None:
            warnings.warn(
                'Please set `root` to the local directory of the vxverse code base; '
                'its modules are imported from there at runtime.'
            )

        if model_name == 'XVERSE-V-13B':
            cfg = model_cfgs['XVERSE-V-13B']
        else:
            raise NotImplementedError

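        # Resolve `ckpt` (a HuggingFace repo id by default): prefer an existing
        # local directory, then the local HF cache, else download the snapshot.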
        ckpt_dir = cfg['ckpt']
        if not osp.isdir(ckpt_dir):
            cache_path = get_cache_path(ckpt_dir)
            if cache_path is not None:
                ckpt_dir = cache_path
            else:
                ckpt_dir = snapshot_download(repo_id=ckpt_dir)
        assert osp.exists(ckpt_dir) and osp.isdir(ckpt_dir)
        ckpt = osp.join(ckpt_dir, 'adapter_and_lora.bin')
        cfg['ckpt'] = ckpt
        model_cfg = OmegaConf.create(cfg)

        self.model_name = model_name

        self.root = root
        sys.path.append(self.root)

        from vxverse.common.registry import registry
        from vxverse.conversation.conversation import CONV_VISION_XVERSE

        device = torch.cuda.current_device()
        self.device = device

        model_cls = registry.get_model_class(model_cfg.arch)
        model = model_cls.from_config(model_cfg)
        model = model.to(device)
        model.eval()
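        # Build the image preprocessor from its registered name. The
        # 'hd_image_train' processor splits high-resolution inputs into
        # multiple patches (handled in generate_inner below).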
        vis_processor_cfg = OmegaConf.create(dict(name='hd_image_train', image_size=224))
        vis_processor = registry.get_processor_class(
            vis_processor_cfg.name
        ).from_config(vis_processor_cfg)

        self.model = model
        self.vis_processor = vis_processor
        self.vis_processor_cfg = vis_processor_cfg

        self.CONV_VISION = CONV_VISION_XVERSE
        self.CONV_VISION.system = ''
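        # Stop sequences given as raw token ids (presumably '###'-style
        # separators in the underlying tokenizer); forwarded to model.generate
        # in generate_inner.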
        stop_words_ids = [[835], [2277, 29937]]
        self.stop_words_ids = stop_words_ids
        default_kwargs = dict(max_new_tokens=512)
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs

    def generate_inner(self, message, dataset=None):
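        """Answer a single-image query.

        `message` is a list of typed dicts (text / image entries);
        `message_to_promptimg` collapses it into one prompt string and one
        image path, since this model does not support interleaved inputs
        (INTERLEAVE = False).
        """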
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        image = Image.open(image_path).convert('RGB')
        image = self.vis_processor(image)

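        # The 'hd_image_train' processor returns a stack of patches with shape
        # (num_patches, C, H, W), so record the patch count per image; other
        # processors return a single image that only needs a batch dimension.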
        if self.vis_processor_cfg.name == 'hd_image_train':
            patches_per_image = [[image.shape[0]]]
            image = [image]
        else:
            patches_per_image = None
            image = image.unsqueeze(0)

        chat_state = self.CONV_VISION.copy()
        texts = self.prepare_texts([prompt], chat_state)
        texts = [text.lstrip() for text in texts]
        answers = self.model.generate(
            image,
            texts,
            patches_per_images=patches_per_image,
            do_sample=False,
            stop_words_ids=self.stop_words_ids,
            **self.kwargs
        )
        return answers[0]

    def prepare_texts(self, texts, conv_temp):
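        """Render each prompt through the conversation template, prefixed with
        the '<ImageHere>' placeholder that marks where image tokens go."""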
        convs = [conv_temp.copy() for _ in range(len(texts))]
        for conv, text in zip(convs, texts):
            conv.append_message(conv.roles[0], '<ImageHere>\n{}'.format(text))
            conv.append_message(conv.roles[1], None)
        texts = [conv.get_prompt() for conv in convs]
        return texts
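
# Minimal usage sketch (illustrative, not part of the original module).
# Assumptions: the vxverse code base is cloned at './vxverse_repo' and a local
# image 'demo.jpg' exists; the message follows VLMEvalKit's typed-dict format.
if __name__ == '__main__':
    model = VXVERSE(model_name='XVERSE-V-13B', root='./vxverse_repo')
    message = [
        dict(type='image', value='demo.jpg'),
        dict(type='text', value='Describe this image.'),
    ]
    print(model.generate_inner(message))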