xcomposer.py 5.83 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
from PIL import Image
from ..base import BaseModel
from ...smp import *
from ...dataset import DATASET_TYPE


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


class XComposer(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='internlm/internlm-xcomposer-vl-7b', **kwargs):
        assert model_path is not None
        self.model_path = model_path

        model = AutoModel.from_pretrained(self.model_path, device_map='cpu', trust_remote_code=True).cuda().eval()
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        model.tokenizer = tokenizer
        self.model = model
        self.device = self.model.internlm_model.model.embed_tokens.weight.device
        self.eoh = '<TOKENS_UNUSED_0>'
        self.eoa = '<TOKENS_UNUSED_1>'
        stop_words_ids = [
            torch.tensor([103027]).to(self.device),  # end of human
            torch.tensor([103028]).to(self.device),  # end of bot
        ]
        default_kwargs = {
            'max_new_tokens': 512, 'num_beams': 5, 'do_sample': False,
            'min_length': 1, 'repetition_penalty': 1.5, 'length_penalty': 1.0
        }
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def generate_inner(self, message, dataset=None):
        if len(message) == 2:
            if message[0]['type'] == 'text' and message[1]['type'] == 'image':
                message = [message[1], message[0]]
        kwargs = cp.deepcopy(self.kwargs)
        if dataset is not None:
            if DATASET_TYPE(dataset) == 'MCQ':
                kwargs['max_new_tokens'] = 5
                kwargs['num_beams'] = 5

        with torch.cuda.amp.autocast():
            with torch.no_grad():
                prompt_embs = self.message_to_prompt_embs(message, dataset)
                outputs = self.model.internlm_model.generate(
                    inputs_embeds=prompt_embs,
                    stopping_criteria=self.stopping_criteria,
                    **kwargs
                )

        output_token = outputs[0]
        if output_token[0] == 0:
            output_token = output_token[1:]
        if output_token[0] == 1:
            output_token = output_token[1:]
        output_text = self.model.tokenizer.decode(output_token, add_special_tokens=False)

        output_text = output_text.split(self.model.eoa)[0]
        output_text = output_text.split('<|Bot|>')[-1].strip()
        return output_text

    def message_to_prompt_embs(self, message, dataset=None):
        assert isinstance(message, list)
        img_embeds = []
        prompt_full = '<|User|>: '
        for msg in message:
            if msg['type'] == 'text':
                prompt_full += msg['value']
            elif msg['type'] == 'image':
                image = Image.open(msg['value']).convert('RGB')
                image = self.model.vis_processor(image).unsqueeze(0).to(self.device)
                img_embeds.append(self.model.encode_img(image))
                prompt_full += '<ImageHere>'

        prompt_full += self.model.eoh + ' <|Bot|>: '
        if dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
            prompt_full += 'Answer: The answer is '
        elif dataset is not None and DATASET_TYPE(dataset) in ['VQA', 'QA', 'Y/N']:
            prompt_full += 'Answer: '

        prompt_segs = prompt_full.split('<ImageHere>')
        assert len(prompt_segs) == len(img_embeds) + 1

        prompt_seg_tokens = [
            self.model.tokenizer(seg, return_tensors='pt', add_special_tokens=(i == 0)).to(self.device).input_ids.long()
            for i, seg in enumerate(prompt_segs)
        ]
        prompt_seg_embs = [self.model.internlm_model.model.embed_tokens(seg) for seg in prompt_seg_tokens]
        all_embeddings = []
        for i in range(len(img_embeds)):
            all_embeddings.extend([prompt_seg_embs[i], img_embeds[i]])
        all_embeddings.append(prompt_seg_embs[-1])
        prompt_embs = torch.cat(all_embeddings, dim=1)
        return prompt_embs

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if DATASET_TYPE(dataset) == 'MCQ':
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = ''
        for key, item in options.items():
            options_prompt += f'{key}. {item}\n'
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        context = 'N/A' if hint is None else hint
        mid_prompt = 'Context: ' + context + '\nQuestion: ' + question
        if len(options_prompt):
            mid_prompt += '\nOptions: ' + options_prompt

        if len(options):
            txt_prompt = 'Please answer this question by choosing the correct choice.'
        else:
            txt_prompt = 'Please answer this question directly. '
        prompt = txt_prompt + mid_prompt
        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])
        return message