"""Evaluate an InternVL chat model on the MME benchmark.

Reads tab-separated ``img\\tquestion\\tgt`` lines from every file under
``--root``, asks the model each question about the referenced image, and
writes the post-processed answers (one ``img\\tquestion\\tgt\\tresponse``
line per sample) to ``<checkpoint-basename>/<filename>``.
"""
import argparse
import os
import re

import torch
from internvl.model.internvl_chat import InternVLChatModel
from internvl.train.dataset import build_transform, dynamic_preprocess
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer


def load_image(image_file, input_size=224, dynamic=False, use_thumbnail=False, max_num=6):
    """Load an image and return a ``(num_tiles, 3, H, W)`` float tensor.

    The original version read ``args.dynamic``, ``use_thumbnail`` and
    ``args.max_num`` from module globals that only exist after argument
    parsing; they are now explicit parameters so the function is
    self-contained.

    Args:
        image_file: Path to the image on disk.
        input_size: Side length each tile is resized to.
        dynamic: If True, split the image into up to ``max_num`` tiles via
            ``dynamic_preprocess``; otherwise use the whole image as one tile.
        use_thumbnail: Forwarded to ``dynamic_preprocess`` — appends a global
            thumbnail tile when tiling is enabled.
        max_num: Maximum number of tiles when ``dynamic`` is enabled.

    Returns:
        torch.Tensor stacking the transformed tile(s).
    """
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(is_train=False, input_size=input_size)
    if dynamic:
        images = dynamic_preprocess(image, image_size=input_size,
                                    use_thumbnail=use_thumbnail, max_num=max_num)
    else:
        images = [image]
    pixel_values = torch.stack([transform(img) for img in images])
    return pixel_values


def post_processing(response):
    """Normalize a model answer toward a bare lower-case English yes/no.

    Maps common Chinese yes/no tokens to English, folds true/false to
    yes/no, and strips newlines plus any remaining CJK characters.
    """
    # Order matters: '不是' ("is not") must be replaced before '是' ("is"),
    # otherwise the negative form would be mapped to "No" + leftover text.
    response = response.replace('\n', '').replace('不是', 'No').replace('是', 'Yes').replace('否', 'No')
    response = response.lower().replace('true', 'yes').replace('false', 'no')
    # Drop any remaining Chinese characters (CJK Unified Ideographs range).
    pattern = re.compile(r'[\u4e00-\u9fa5]')
    response = re.sub(pattern, '', response)
    return response


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='')
    parser.add_argument('--root', type=str, default='./Your_Results')
    parser.add_argument('--num-beams', type=int, default=5)
    parser.add_argument('--top-k', type=int, default=50)
    parser.add_argument('--top-p', type=float, default=0.9)
    # BUG FIX: the original used ``type=bool``, which converts ANY non-empty
    # string — including "False" — to True. A plain flag is what was intended.
    parser.add_argument('--sample', action='store_true', default=False)
    parser.add_argument('--dynamic', action='store_true')
    parser.add_argument('--max-num', type=int, default=6)
    parser.add_argument('--load-in-8bit', action='store_true')
    parser.add_argument('--load-in-4bit', action='store_true')
    parser.add_argument('--auto', action='store_true')
    args = parser.parse_args()

    if args.auto:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    kwargs = {'device_map': 'auto'} if args.auto else {}
    prompt = 'Answer the question using a single word or phrase.'

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint, trust_remote_code=True, use_fast=False)
    model = InternVLChatModel.from_pretrained(
        args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
        load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit,
        **kwargs).eval()
    # Quantized / auto-sharded models are already placed on device(s).
    if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
        model = model.cuda()
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    use_thumbnail = model.config.use_thumbnail

    total_params = sum(p.numel() for p in model.parameters()) / 1e9
    # Beam search on very large models (or combined with dynamic tiling) is
    # too memory-hungry, so fall back to greedy decoding.
    if total_params > 20 or args.dynamic:
        args.num_beams = 1
        print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
    else:
        print(f'[test] total_params: {total_params}B')
    print(f'[test] image_size: {image_size}')
    print(f'[test] template: {model.config.template}')
    print(f'[test] dynamic_image_size: {args.dynamic}')
    print(f'[test] use_thumbnail: {use_thumbnail}')
    print(f'[test] max_num: {args.max_num}')

    output = os.path.basename(args.checkpoint)
    os.makedirs(output, exist_ok=True)

    # Hoisted out of the per-sample loop: identical for every sample.
    generation_config = dict(
        do_sample=args.sample,
        top_k=args.top_k,
        top_p=args.top_p,
        num_beams=args.num_beams,
        max_new_tokens=20,
        eos_token_id=tokenizer.eos_token_id,
    )

    for filename in os.listdir(args.root):
        # Context managers guarantee both handles are closed even if a
        # sample raises (the originals leaked on error).
        with open(os.path.join(args.root, filename), 'r', encoding='utf-8') as fin, \
                open(os.path.join(output, filename), 'w', encoding='utf-8') as fout:
            lines = fin.readlines()
            category = filename.replace('.txt', '')
            for line in tqdm(lines):
                img, question, gt = line.strip().split('\t')
                question = question + ' ' + prompt
                img_path = os.path.join(
                    '../../data/mme/MME_Benchmark_release_version', category, img)
                assert os.path.exists(img_path), img_path
                pixel_values = load_image(
                    img_path, image_size, dynamic=args.dynamic,
                    use_thumbnail=use_thumbnail, max_num=args.max_num,
                ).cuda().to(torch.bfloat16)
                response = model.chat(
                    tokenizer=tokenizer,
                    pixel_values=pixel_values,
                    question=question,
                    generation_config=generation_config,
                    verbose=True,
                )
                response = post_processing(response)
                print(img, question, gt, response, sep='\t', file=fout)