# -*- coding: utf-8 -*-
"""OFA.ipynb

Original file is located at (OFA/colab.md -> Image Captioning)
    https://colab.research.google.com/drive/1jogyZ-2rdHU3XxZOf3TBfhex1XHqX-1m

# **OFA**

You can use different instructions to perform various tasks (i.e., image
captioning, visual grounding, VQA and grounded captioning) with just one model.
"""

"""## **Preparation**
Below you just need to import required packages, and check whether to use GPU or FP16.
"""

import torch
import numpy as np
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
# Missing zip file bug fix: python3 -c "import nltk; nltk.download('averaged_perceptron_tagger')"
from tasks.mm_tasks.refcoco import RefcocoTask
from models.ofa import OFAModel
from PIL import Image
import time

# Register the refcoco task so "--task=refcoco" below resolves.
tasks.register_task('refcoco', RefcocoTask)

# turn on cuda if GPU is available
use_cuda = torch.cuda.is_available()
# use fp16 only when GPU is available
use_fp16 = False

# Specify some options for evaluation. The leading "" stands in for argv[0].
parser = options.get_generation_parser()
input_args = [
    "",
    "--task=refcoco",
    "--beam=10",
    "--path=checkpoints/ofa_large.pt",
    "--bpe-dir=utils/BPE",
    "--no-repeat-ngram-size=3",
    "--patch-image-size=384",
]
args = options.parse_args_and_arch(parser, input_args)
cfg = convert_namespace_to_omegaconf(args)

"""## **Build Model**
Below you can build your model and load the weights from the given checkpoint,
and also build a generator.
"""

# Load pretrained ckpt & config
task = tasks.setup_task(cfg.task)
models, cfg = checkpoint_utils.load_model_ensemble(
    utils.split_paths(cfg.common_eval.path),
    task=task
)

# Move models to GPU (fp16 only makes sense on GPU, see use_fp16 above)
for model in models:
    model.eval()
    if use_fp16:
        model.half()
    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
        model.cuda()
    model.prepare_for_inference_(cfg)

# Initialize generator
generator = task.build_generator(models, cfg.generation)

"""## **Preprocess**
We demonstrate the required transformation functions for preprocessing inputs.
"""

# Image transform: RGB-convert, resize to the task's patch size, then
# normalize to [-1, 1] (mean/std 0.5 per channel).
from torchvision import transforms

mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize(
        (task.cfg.patch_image_size, task.cfg.patch_image_size),
        interpolation=Image.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Text preprocess
bos_item = torch.LongTensor([task.src_dict.bos()])
eos_item = torch.LongTensor([task.src_dict.eos()])
pad_idx = task.src_dict.pad()


def get_symbols_to_strip_from_output(generator):
    """Return the set of token ids to strip from the generator's output.

    Falls back to {bos, eos} when the generator does not declare its own
    symbols_to_strip_from_output attribute.
    """
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.bos, generator.eos}


def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
    """Turn a generated token-id tensor into (text, bin-token, image-code) strings.

    NOTE(review): the copy of this function in the file was corrupted — an
    HTML-style stripper deleted the ``<bin_`` / ``<code_`` markers and the code
    between them. The branch structure below is reconstructed from the public
    OFA colab notebook; verify against the upstream source.
    """
    x = tgt_dict.string(x.int().cpu(), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator))
    token_result = []
    bin_result = []
    img_result = []
    for token in x.strip().split():
        if token.startswith('<bin_'):
            # Quantized-coordinate token (visual grounding output).
            bin_result.append(token)
        elif token.startswith('<code_'):
            # Image-code token.
            img_result.append(token)
        else:
            if bpe is not None:
                token = bpe.decode('{}'.format(token))
            if tokenizer is not None:
                token = tokenizer.decode(token)
            # BPE pieces that start with a space begin a new word; otherwise
            # glue the piece onto the previous word.
            if token.startswith(' ') or len(token_result) == 0:
                token_result.append(token.strip())
            else:
                token_result[-1] += token
    return ' '.join(token_result), ' '.join(bin_result), ' '.join(img_result)


def coord2bin(coords, w_resize_ratio, h_resize_ratio):
    """Map an 'x0 y0 x1 y1' pixel-coordinate string to four <bin_i> tokens.

    Coordinates are scaled by the resize ratio, normalized by the task's
    max_image_size, and quantized into num_bins - 1 discrete bins.
    """
    coord_list = [float(coord) for coord in coords.strip().split()]
    bin_list = []
    bin_list += ["<bin_{}>".format(int(round(coord_list[0] * w_resize_ratio / task.cfg.max_image_size * (task.cfg.num_bins - 1))))]
    bin_list += ["<bin_{}>".format(int(round(coord_list[1] * h_resize_ratio / task.cfg.max_image_size * (task.cfg.num_bins - 1))))]
    bin_list += ["<bin_{}>".format(int(round(coord_list[2] * w_resize_ratio / task.cfg.max_image_size * (task.cfg.num_bins - 1))))]
    bin_list += ["<bin_{}>".format(int(round(coord_list[3] * h_resize_ratio / task.cfg.max_image_size * (task.cfg.num_bins - 1))))]
    return ' '.join(bin_list)


def bin2coord(bins, w_resize_ratio, h_resize_ratio):
    """Inverse of coord2bin: map four <bin_i> tokens back to pixel coordinates.

    bin[5:-1] strips the literal '<bin_' prefix and '>' suffix to recover the
    integer bin index.
    """
    bin_list = [int(bin[5:-1]) for bin in bins.strip().split()]
    coord_list = []
    coord_list += [bin_list[0] / (task.cfg.num_bins - 1) * task.cfg.max_image_size / w_resize_ratio]
    coord_list += [bin_list[1] / (task.cfg.num_bins - 1) * task.cfg.max_image_size / h_resize_ratio]
    coord_list += [bin_list[2] / (task.cfg.num_bins - 1) * task.cfg.max_image_size / w_resize_ratio]
    coord_list += [bin_list[3] / (task.cfg.num_bins - 1) * task.cfg.max_image_size / h_resize_ratio]
    return coord_list


def encode_text(text, length=None, append_bos=False, append_eos=False):
    """Encode a text string into a LongTensor of token ids.

    Special <bin_*> / <code_*> tokens bypass BPE; everything else is BPE-encoded.
    Optionally truncates to `length` and wraps with bos/eos items.

    NOTE(review): this function was truncated mid-line in the file; the body
    past the first list element is completed from the public OFA colab
    notebook — verify against the upstream source.
    """
    line = [
        task.bpe.encode(' {}'.format(word.strip()))
        if not word.startswith('<code_') and not word.startswith('<bin_') else word
        for word in text.strip().split()
    ]
    line = ' '.join(line)
    s = task.tgt_dict.encode_line(
        line=line,
        add_if_not_exist=False,
        append_eos=False
    ).long()
    if length is not None:
        s = s[:length]
    if append_bos:
        s = torch.cat([bos_item, s])
    if append_eos:
        s = torch.cat([s, eos_item])
    return s