"""Run one vision-language inference step from a saved chat transcript.

The script loads a JSON list of chat messages, wraps the coordinates in
previous assistant actions with ``<|box_start|>`` / ``<|box_end|>`` tokens,
feeds the conversation to a model loaded via ``AutoModelForVision2Seq``,
prints the generated text, and plots the coordinate extracted from it.
"""
import json
import re

# Transcript used when --messages_path is not given.
DEFAULT_MESSAGES_PATH = "./data/test_messages_07.json"

# Matches start_box='(x,y)' / end_box='(x, y)'; whitespace after the comma
# is optional.
_BOX_PATTERN = re.compile(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'")
# Re-emits the coordinate without whitespace, wrapped in the box tokens.
_BOX_REPLACEMENT = r"\1='<|box_start|>(\2,\3)<|box_end|>'"


def standardize_messages(messages):
    """Convert plain-string message contents to the list-of-parts format.

    Every message whose ``"content"`` is a ``str`` gets it replaced by
    ``[{"type": "text", "text": <content>}]``. Other contents are left
    untouched.

    Args:
        messages: List of message dicts, each with a ``"content"`` key.

    Returns:
        The same list object; messages are modified in place.
    """
    for message in messages:
        content = message["content"]
        # If the content is a plain string, convert it to the standard format.
        if isinstance(content, str):
            message["content"] = [{"type": "text", "text": content}]
    return messages


def add_box_token(input_string):
    """Wrap ``start_box`` / ``end_box`` coordinates in box tokens.

    Only strings containing both ``"Action: "`` and ``"start_box="`` are
    rewritten; anything else (including non-string values) is returned
    unchanged.

    The text before the first ``"Action: "`` is kept verbatim. Each segment
    after an ``"Action: "`` marker is stripped, has every
    ``start_box='(x,y)'`` / ``end_box='(x,y)'`` rewritten to
    ``...='<|box_start|>(x,y)<|box_end|>'``, and the segments are re-joined
    with a blank line. Only the first ``"Action: "`` label is re-emitted.

    Args:
        input_string: Assistant message text.

    Returns:
        The rewritten string, or ``input_string`` itself when no rewrite
        applies.
    """
    if not isinstance(input_string, str):
        return input_string
    if "Action: " not in input_string or "start_box=" not in input_string:
        return input_string

    prefix, *actions = input_string.split("Action: ")
    # A single regex substitution keeps matching and replacing in sync, so
    # coordinates written with a space after the comma are wrapped as well.
    processed_actions = [
        _BOX_PATTERN.sub(_BOX_REPLACEMENT, action.strip()) for action in actions
    ]
    return prefix + "Action: " + "\n\n".join(processed_actions)


def _load_messages(path):
    """Load the transcript at *path* and prepare it for the chat template.

    Assistant messages are passed through ``add_box_token`` and printed;
    afterwards all messages are normalised with ``standardize_messages``.
    """
    with open(path, encoding="utf-8") as handle:
        messages = json.load(handle)

    for message in messages:
        if message["role"] == "assistant":
            message["content"] = add_box_token(message["content"])
            print(message["content"])

    return standardize_messages(messages)


def main():
    """Parse CLI arguments, run generation once, and plot the result."""
    from argparse import ArgumentParser

    # Heavy / project-local dependencies are imported here so that the text
    # helpers above stay importable without them.
    import torch
    from transformers import AutoModelForVision2Seq, AutoProcessor

    from utils.coordinate_extract import extract_coordinates
    from utils.plot_image import plot

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--messages_path", type=str, default=DEFAULT_MESSAGES_PATH)
    args = parser.parse_args()

    messages = _load_messages(args.messages_path)

    model = AutoModelForVision2Seq.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(args.model_path)

    # Optional: override the image processor's size limits.
    # if hasattr(processor, "image_processor"):
    #     image_processor = processor.image_processor
    #     # Replace size
    #     image_processor.size = {
    #         "shortest_edge": 768,
    #         "longest_edge": 1024
    #     }
    #     processor.image_processor = image_processor

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)  # follow the model's placement instead of assuming CUDA

    generated_ids = model.generate(**inputs, max_new_tokens=400)
    # Drop the prompt tokens so that only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    print(output_text)

    model_raw_response = output_text[0]
    # Please use re to parse the coordinate values
    coordinate = extract_coordinates(model_raw_response)
    model_output_width, model_output_height = coordinate
    plot(model_output_width, model_output_height)


if __name__ == "__main__":
    main()