import torch
from PIL import Image
import os.path as osp
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import copy


class VILA(BaseModel):
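    """Wrapper for NVIDIA's VILA vision-language models within VLMEvalKit.

    Loads a VILA checkpoint via llava's `load_pretrained_model` and answers
    interleaved image/text messages through `generate_inner`.
    """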
    INSTALL_REQ = True
    INTERLEAVE = True

    def __init__(self,
                 model_path='Efficient-Large-Model/Llama-3-VILA1.5-8b',
                 **kwargs):
        try:
            from llava.model.builder import load_pretrained_model
            from llava.mm_utils import get_model_name_from_path
            from llava.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
            from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # noqa E501
            from llava.conversation import conv_templates, SeparatorStyle
        except Exception as err:
            logging.critical('Please install VILA from https://github.com/NVlabs/VILA before using this model.')
            logging.critical('Install VLMEvalKit after installing VILA.')
            logging.critical('VILA is supported only with transformers==4.36.2.')
            raise err

        warnings.warn('Please install the latest version of VILA from GitHub before you evaluate the VILA model.')
        assert osp.exists(model_path) or len(model_path.split('/')) == 2

        model_name = get_model_name_from_path(model_path)

        try:
            self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
                model_path=model_path,
                model_base=None,
                model_name=model_name,
                device='cpu',
                device_map='cpu'
            )
        except Exception as err:
            logging.critical('Error loading VILA model.')
            raise err

        self.model = self.model.cuda()
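        # Pick the conversation template matching the checkpoint size named in the model path.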
        if '3b' in model_path:
            self.conv_mode = 'vicuna_v1'
        elif '8b' in model_path:
            self.conv_mode = 'llama_3'
        elif '13b' in model_path:
            self.conv_mode = 'vicuna_v1'
        elif '40b' in model_path:
            self.conv_mode = 'hermes-2'

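        # Greedy decoding by default (no sampling, single beam); caller-supplied kwargs override these.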
        kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=512, top_p=None, num_beams=1, use_cache=True) # noqa E501

        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Using the following kwargs for generation config: {self.kwargs}')

        self.conv_templates = conv_templates
        self.process_images = process_images
        self.tokenizer_image_token = tokenizer_image_token
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        self.SeparatorStyle = SeparatorStyle
        self.IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
        self.KeywordsStoppingCriteria = KeywordsStoppingCriteria

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        # TODO see if custom prompt needed
        return False

    def generate_inner(self, message, dataset=None):

        content, images = '', []

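        # `message` is a list of dicts with 'type' ('text' or 'image') and 'value'; each image
        # contributes an image-token placeholder so text and images stay interleaved in the prompt.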
        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
            elif msg['type'] == 'image':
                image = Image.open(msg['value']).convert('RGB')
                images.append(image)
                content += (self.DEFAULT_IMAGE_TOKEN + '\n')

        image_tensor = self.process_images(
            images, self.image_processor,
            self.model.config).to(self.model.device, dtype=torch.float16)

        # Support interleave text and image
        conv = self.conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], content)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

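        # tokenizer_image_token replaces each image placeholder in the prompt with IMAGE_TOKEN_INDEX.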
        input_ids = self.tokenizer_image_token(prompt, self.tokenizer, self.IMAGE_TOKEN_INDEX,
                                               return_tensors='pt').unsqueeze(0).cuda()

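        # Stop generation at the conversation template's separator string.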
        stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = self.KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids, images=image_tensor, stopping_criteria=[stopping_criteria], **self.kwargs)

            output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return output
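

# A minimal usage sketch (not part of the original file). It assumes VILA and its
# dependencies are installed and that this class is importable from VLMEvalKit;
# 'demo.jpg' and the import path below are illustrative placeholders.
#
#     from vlmeval.vlm import VILA
#     model = VILA(model_path='Efficient-Large-Model/Llama-3-VILA1.5-8b')
#     message = [
#         {'type': 'image', 'value': 'demo.jpg'},
#         {'type': 'text', 'value': 'Describe this image.'},
#     ]
#     print(model.generate_inner(message))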