# inference_15.py
import torch

import re
import json


from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor


def standardize_messages(messages):
    """Normalize chat messages to the list-of-parts content format.

    Any message whose "content" is a plain string is rewritten in place to
    ``[{"type": "text", "text": <string>}]``; contents that are already
    lists are left untouched.  Returns the same (mutated) list.
    """
    for msg in messages:
        content = msg["content"]
        if isinstance(content, str):
            msg["content"] = [{"type": "text", "text": content}]
    return messages


def add_box_token(input_string):
    """Wrap start_box/end_box coordinates in <|box_start|>/<|box_end|> tokens.

    Splits *input_string* on "Action: ", and inside each action rewrites every
    ``start_box='(x, y)'`` / ``end_box='(x, y)'`` occurrence as
    ``start_box='<|box_start|>(x,y)<|box_end|>'``.  Strings without both an
    "Action: " marker and a "start_box=" field are returned unchanged.
    """
    if "Action: " in input_string and "start_box=" in input_string:
        # Everything up to and including the first "Action: " is kept verbatim.
        prefix = input_string.split("Action: ")[0] + "Action: "
        actions = input_string.split("Action: ")[1:]

        pattern = re.compile(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'")
        processed_actions = []
        for action in actions:
            action = action.strip()
            # BUGFIX: the previous exact-string replace only matched the
            # no-space spelling "(x,y)", so coordinates written "(x, y)" —
            # which the regex deliberately tolerates via \s* — were silently
            # left untagged.  re.sub rewrites every match the regex finds.
            updated_action = pattern.sub(
                lambda m: f"{m.group(1)}='<|box_start|>({m.group(2)},{m.group(3)})<|box_end|>'",
                action,
            )
            processed_actions.append(updated_action)

        # Reconstruct the final string (actions are re-joined with blank lines,
        # matching the original output shape).
        return prefix + "\n\n".join(processed_actions)
    return input_string


# NOTE(review): a second, byte-identical definition of standardize_messages
# used to live here; removed as a duplicate of the one defined above.


result = {}

# Load the recorded chat transcript that will be used as the model prompt.
# (Context manager closes the file handle; the old json.load(open(...)) leaked it.)
with open("./data/test_messages_07.json") as f:
    messages = json.load(f)

# Tag assistant turns so coordinate spans carry explicit box tokens.
for message in messages:
    if message["role"] == "assistant":
        message["content"] = add_box_token(message["content"])
        print(message["content"])


messages = standardize_messages(messages)


if __name__ == "__main__":
    from argparse import ArgumentParser

    from utils.coordinate_extract import extract_coordinates
    from utils.plot_image import plot

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    args = parser.parse_args()

    # fp16 weights, placed automatically across available devices.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="sdpa",
    )

    processor = AutoProcessor.from_pretrained(args.model_path)

    # Tokenize the module-level `messages` transcript prepared above.
    # BUGFIX: move inputs to the model's actual device instead of a
    # hard-coded "cuda" — with device_map="auto" the model may not sit
    # (entirely) on cuda:0, and the old code crashed on CPU-only hosts.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=400)

    # Strip the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)

    model_raw_response = output_text[0]

    # Parse the (x, y) coordinate pair out of the raw model response.
    coordinate = extract_coordinates(model_raw_response)
    model_output_width, model_output_height = coordinate

    plot(model_output_width, model_output_height)