"vscode:/vscode.git/clone" did not exist on "537b7eb6843490c2f18de32d61240946902dbab1"
Commit bc5ebf0f authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #2167 canceled with stages
.. role:: hidden
    :class: hidden-section
.. currentmodule:: {{ module }}

{{ name | underline}}

.. autoclass:: {{ name }}
    :members:

..
    autogenerated from _templates/autosummary/class.rst
    note it does not have :inherited-members:
.. role:: hidden
    :class: hidden-section
.. currentmodule:: {{ module }}

{{ name | underline}}

.. autoclass:: {{ name }}
    :members:
    :special-members: __call__

..
    autogenerated from _templates/callable.rst
    note it does not have :inherited-members:
# flake8: noqa
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import ast
import subprocess
import sys
import pytorch_sphinx_theme
from sphinx.builders.html import StandaloneHTMLBuilder
sys.path.insert(0, os.path.abspath('../../'))
# -- Project information -----------------------------------------------------
project = 'VLMEvalKit'
copyright = '2023, VLMEvalKit'
author = 'VLMEvalKit Authors'
# The full version, including alpha/beta/rc tags
version_file = '../../vlmeval/__init__.py'
def get_version():
    with open(version_file, 'r') as f:
        file_content = f.read()
    # Parse the file content into an abstract syntax tree (AST)
    tree = ast.parse(file_content, filename=version_file)
    # Iterate through the body of the AST, looking for an assignment to __version__
    for node in tree.body:
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == '__version__':
                    return node.value.s
    raise ValueError('__version__ not found')
release = get_version()
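# Note: get_version() reads the version statically via `ast` instead of importing
# vlmeval, so building the docs does not require the package's runtime
# dependencies. It expects a plain top-level assignment in vlmeval/__init__.py,
# e.g.:
#
#     __version__ = '0.1.0'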
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.intersphinx',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'myst_parser',
    'sphinx_copybutton',
    'sphinx_tabs.tabs',
    'notfound.extension',
    'sphinxcontrib.jquery',
    'sphinx_design',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffixes as a list of strings:
#
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown',
}
language = 'cn'
# The master toctree document.
root_doc = 'index'
html_context = {
    'github_version': 'latest',
}
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
# yapf: disable
html_theme_options = {
    'menu': [
        {
            'name': 'GitHub',
            'url': 'https://github.com/open-compass/VLMEvalKit'
        },
    ],
    # Specify the language of the shared menu
    'menu_lang': 'cn',
    # Disable the default "edit on GitHub" button
    'default_edit_on_github': False,
}
# yapf: enable
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_css_files = [
    'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.css',
    'css/readthedocs.css'
]
html_js_files = [
    'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.js',
    'js/custom.js'
]
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'vlmevalkitdoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
    (root_doc, 'vlmevalkit.tex', 'VLMEvalKit Documentation', author, 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(root_doc, 'vlmevalkit', 'VLMEvalKit Documentation', [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
    (root_doc, 'vlmevalkit', 'VLMEvalKit Documentation', author,
     'VLMEvalKit Authors', 'AGI evaluation toolbox and benchmark.',
     'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
# set priority when building html
StandaloneHTMLBuilder.supported_image_types = [
    'image/svg+xml', 'image/gif', 'image/png', 'image/jpeg'
]
# -- Extension configuration -------------------------------------------------
# Ignore >>> when copying code
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_is_regexp = True
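# With the regexp above, copying a doctest-style snippet such as
#     >>> for f in files:
#     ...     print(f)
# yields only the statements, with the '>>> ' / '... ' prompts stripped.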
# Auto-generated header anchors
myst_heading_anchors = 3
# Enable "colon_fence" extension of myst.
myst_enable_extensions = ['colon_fence', 'dollarmath']
# Configuration for intersphinx
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
    'mmengine': ('https://mmengine.readthedocs.io/en/latest/', None),
    'transformers': ('https://huggingface.co/docs/transformers/main/en/', None),
}
napoleon_custom_sections = [
    # Custom sections for data elements.
    ('Meta fields', 'params_style'),
    ('Data fields', 'params_style'),
]
# Disable docstring inheritance
autodoc_inherit_docstrings = False
# Mock some imports during generate API docs.
autodoc_mock_imports = ['rich', 'attr', 'einops']
# Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none'
# The not found page
notfound_template = '404.html'
def builder_inited_handler(app):
    subprocess.run(['./cp_origin_docs.sh'])


def setup(app):
    app.connect('builder-inited', builder_inited_handler)
#!/usr/bin/env bash
# Copy *.md files from docs/en/ that do not yet have a Chinese translation
for filename in $(find ../en/ -name '*.md' -printf "%P\n");
do
    mkdir -p "$(dirname "$filename")"
    cp -n "../en/$filename" "./$filename"
done
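# Intended to be run from the Chinese docs source directory (assumed layout:
# the en and zh-CN doc trees sit side by side), e.g.:
#     cd docs/zh-CN && ./cp_origin_docs.sh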
[html writers]
table_style: colwidths-auto
Welcome to the VLMEvalKit Chinese Tutorial!
===========================================

How to Get Started with VLMEvalKit
----------------------------------

To help users get started quickly, we recommend the following workflow:

- If you want to use VLMEvalKit, we recommend reading the `Quick Start`_ section first to set up the environment and run a mini experiment to familiarize yourself with the workflow.
- If you want to customize more modules, such as adding datasets and models, see the `Advanced Tutorials`_.

We always welcome users' PRs and Issues to improve VLMEvalKit!

.. _Quick Start:
.. toctree::
   :maxdepth: 1
   :caption: Quick Start

   Quickstart.md

.. .. _Tutorials:
.. .. toctree::
..    :maxdepth: 1
..    :caption: Tutorials

..    user_guides/framework_overview.md

.. _Advanced Tutorials:
.. toctree::
   :maxdepth: 1
   :caption: Advanced Tutorials

   Development.md
   ConfigSystem.md

.. .. _Notes:
.. .. toctree::
..    :maxdepth: 1
..    :caption: Notes

..    notes/contribution_guide.md

Indices and Tables
==================

* :ref:`genindex`
* :ref:`search`
decord; platform_machine != 'arm64'
eva-decord; platform_machine == 'arm64'
gradio
huggingface_hub
imageio
matplotlib
numpy
omegaconf
openai
opencv-python>=4.4.0.46
openpyxl
pandas
pillow
portalocker
protobuf
python-dotenv
requests
rich
sentencepiece
setuptools
sty
tabulate
tiktoken
timeout-decorator
torch
tqdm
transformers
typing_extensions
validators
xlsxwriter
docutils==0.18.1
modelindex
myst-parser
-e git+https://github.com/open-compass/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==6.1.3
sphinx-copybutton
sphinx-design
sphinx-notfound-page
sphinx-tabs
sphinxcontrib-jquery
tabulate
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, cv2\n",
"import string\n",
"import os.path as osp\n",
"import numpy as np\n",
"from collections import defaultdict\n",
"from vlmeval.smp import ls, load, dump, download_file, encode_image_file_to_base64, md5, mrlines\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import multiprocessing as mp\n",
"from PIL import Image, ImageFont, ImageDraw\n",
"\n",
"font_URL = 'http://opencompass.openxlab.space/utils/Fonts/timesb.ttf'\n",
"font_file = 'timesb.ttf'\n",
"if not osp.exists(font_file):\n",
" download_file(font_URL)\n",
" \n",
"test_split_URL = 'https://s3-us-east-2.amazonaws.com/prior-datasets/ai2d_test_ids.csv'\n",
"test_split_file = 'ai2d_test_ids.csv'\n",
"if not osp.exists(test_split_file):\n",
" download_file(test_split_URL)\n",
" \n",
"test_ids = set(mrlines(test_split_file))\n",
" \n",
"def proper_font_size(font_file, wh, text, ratio=1):\n",
" font_size = 2\n",
" while True:\n",
" font = ImageFont.truetype(font_file, font_size)\n",
" real_box = font.getbbox(text)\n",
" real_wh = (real_box[2] - real_box[0], real_box[3] - real_box[1])\n",
" if real_wh[0] > wh[0] * ratio or real_wh[1] > wh[1] * ratio:\n",
" break\n",
" font_size += 1\n",
" return font_size\n",
"\n",
"def cover_image(ann_path):\n",
" data = load(ann_path)\n",
" texts = list(data['text'].values())\n",
" raw_img = ann_path.replace('annotations', 'images').replace('.json', '')\n",
" tgt_img = raw_img.replace('images', 'images_abc')\n",
" img = Image.open(raw_img)\n",
" draw = ImageDraw.Draw(img)\n",
" for text in texts:\n",
" st, ed = tuple(text['rectangle'][0]), tuple(text['rectangle'][1])\n",
" T = text['replacementText']\n",
" draw.rectangle((st, ed), fill='white')\n",
" font_size = proper_font_size(font_file, (ed[0] - st[0], ed[1] - st[1]), T, ratio=1)\n",
" font = ImageFont.truetype(font_file, font_size)\n",
" text_box = font.getbbox(T)\n",
" text_wh = (text_box[2] - text_box[0], text_box[3] - text_box[1])\n",
" cx, cy = (st[0] + ed[0]) // 2, st[1]\n",
" stx = cx - text_wh[0] // 2\n",
" sty = cy - text_wh[1] // 2\n",
" draw.text((stx, sty), T, font=font, fill='black')\n",
" img.save(tgt_img) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Process for no mask images\n",
"test_ids = set(mrlines(test_split_file))\n",
"\n",
"def detect_image_color(image):\n",
" gray_image = image.convert('L')\n",
" mean_brightness = np.mean(np.array(gray_image))\n",
" if mean_brightness < 127:\n",
" return 'white'\n",
" else:\n",
" return 'black'\n",
"\n",
"def cover_image(ann_path):\n",
" data = load(ann_path)\n",
" texts = list(data['text'].values())\n",
" raw_img = ann_path.replace('annotations', 'images').replace('.json', '')\n",
" tgt_img = raw_img.replace('images', 'images_abc')\n",
" img = Image.open(raw_img)\n",
" draw = ImageDraw.Draw(img)\n",
" color = detect_image_color(img)\n",
" font_size = 0\n",
" for text in texts:\n",
" st, ed = tuple(text['rectangle'][0]), tuple(text['rectangle'][1])\n",
" font_size += (ed[1] - st[1])\n",
" if len(texts) != 0:\n",
" font_size /= len(texts)\n",
" else:\n",
" font_size = 2\n",
" for text in texts:\n",
" st, ed = tuple(text['rectangle'][0]), tuple(text['rectangle'][1])\n",
" T = text['replacementText']\n",
" for i in range(2):\n",
" draw.rectangle(\n",
" [(st[0] - i, st[1] - i), (ed[0] + i, ed[1] + i)],\n",
" outline=color\n",
" )\n",
" font = ImageFont.truetype(font_file, font_size)\n",
" text_box = font.getbbox(T)\n",
" text_wh = (text_box[2] - text_box[0], text_box[3] - text_box[1])\n",
" cx, cy = (st[0] + ed[0]) // 2, st[1]\n",
" stx = cx - text_wh[0] // 2\n",
" sty = cy - text_wh[1] * 1.5\n",
" if sty < 0:\n",
" sty = cy + text_wh[1] * 1.3\n",
" draw.text((stx, sty), T, font=font, fill=color)\n",
" img.save(tgt_img) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"download_file('https://ai2-public-datasets.s3.amazonaws.com/diagrams/ai2d-all.zip')\n",
"os.system('unzip -o ai2d-all.zip')\n",
"\n",
"images = ls('ai2d/images/')\n",
"questions = ls('ai2d/questions/')\n",
"annotations = ls('ai2d/annotations/')\n",
"cates = load('ai2d/categories.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pool = mp.Pool(32)\n",
"pool.map(cover_image, annotations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def puncproc(inText):\n",
" import re\n",
" outText = inText\n",
" punct = [\n",
" ';', r'/', '[', ']', '\"', '{', '}', '(', ')', '=', '+', '\\\\', '_', '-',\n",
" '>', '<', '@', '`', ',', '?', '!'\n",
" ]\n",
" commaStrip = re.compile('(\\d)(,)(\\d)') # noqa: W605\n",
" periodStrip = re.compile('(?!<=\\d)(\\.)(?!\\d)') # noqa: W605\n",
" for p in punct:\n",
" if (p + ' ' in inText or ' ' + p in inText) or (re.search(commaStrip, inText) is not None):\n",
" outText = outText.replace(p, '')\n",
" else:\n",
" outText = outText.replace(p, ' ')\n",
" outText = periodStrip.sub('', outText, re.UNICODE)\n",
" return outText\n",
"\n",
"def check_choices(line):\n",
" def ischar(s):\n",
" s = str(s)\n",
" if s in ['{}', 'Both', 'None of above']:\n",
" return True\n",
" elif s.startswith('Stage ') and ischar(s[6:]):\n",
" return True\n",
" elif ' and ' in s and np.all([ischar(x) for x in s.split(' and ')]):\n",
" return True\n",
" elif len(s) <= 2:\n",
" return True\n",
" elif len(puncproc(s).split()) > 1:\n",
" return np.all([ischar(x) for x in puncproc(s).split()])\n",
" return False\n",
" n_char = sum([ischar(line[x]) for x in 'ABCD'])\n",
" return n_char >= 3\n",
"\n",
"def check_question(question):\n",
" words = puncproc(question).split()\n",
" for ch in string.ascii_lowercase + string.ascii_uppercase:\n",
" if ch in words:\n",
" return True\n",
" return False\n",
"\n",
"def is_abc(abc, choices, question):\n",
" if abc == 0:\n",
" return False\n",
" if check_choices(choices):\n",
" return True\n",
" if check_question(question):\n",
" return True\n",
" return False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_all = defaultdict(list)\n",
"for qfile in questions:\n",
" data = load(qfile)\n",
" idx = data['imageName'].split('.')[0]\n",
" if idx not in test_ids:\n",
" continue\n",
" image_pth = qfile.replace('questions', 'images').replace('.json', '')\n",
" cate = cates[image_pth.split('/')[-1]]\n",
" for q, qmeta in data['questions'].items():\n",
" assert '.png-' in qmeta['questionId']\n",
" main, sub = qmeta['questionId'].split('.png-')\n",
" idx = int(main) * 100 + int(sub)\n",
" \n",
" answers = qmeta['answerTexts']\n",
" correct = qmeta['correctAnswer']\n",
" \n",
" data_all['index'].append(idx)\n",
" data_all['question'].append(q)\n",
" assert len(answers) == 4\n",
" for c, a in zip('ABCD', answers):\n",
" data_all[c].append(a)\n",
" data_all['answer'].append('ABCD'[qmeta['correctAnswer']])\n",
" data_all['category'].append(cate)\n",
" data_all['abcLabel'].append(qmeta['abcLabel'])\n",
" abc = is_abc(qmeta['abcLabel'], {x: data_all[x][-1] for x in 'ABCD'}, q)\n",
" # if qmeta['abcLabel'] and not abc:\n",
" # print(qmeta['abcLabel'], {x: data_all[x][-1] for x in 'ABCD'}, q)\n",
" data_all['image_path'].append(image_pth.replace('images', 'images_abc') if abc else image_pth)\n",
"data = pd.DataFrame(data_all)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"images = []\n",
"image_seen = {}\n",
"for idx, pth in zip(data['index'], data['image_path']):\n",
" images.append(encode_image_file_to_base64(pth))\n",
"\n",
"data['image'] = images\n",
"dump(data, 'AI2D_TEST.tsv')\n",
"print(md5('AI2D_TEST.tsv'))"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import sys
from vlmeval import *
from vlmeval.dataset import SUPPORTED_DATASETS

FAIL_MSG = 'Failed to obtain answer via API.'
root = sys.argv[1]
if root[-1] in '/\\':
    root = root[:-1]
model_name = root.split('/')[-1]

for d in SUPPORTED_DATASETS:
    fname = f'{model_name}_{d}.xlsx'
    pth = osp.join(root, fname)
    if osp.exists(pth):
        data = load(pth)
        # Detect Failure
        assert 'prediction' in data
        data['prediction'] = [str(x) for x in data['prediction']]
        fail = [FAIL_MSG in x for x in data['prediction']]
        if sum(fail):
            nfail = sum(fail)
            ntot = len(fail)
            print(f'Model {model_name} x Dataset {d}: {nfail} out of {ntot} failed. {nfail / ntot * 100: .2f}%. ')
        eval_files = ls(root, match=f'{model_name}_{d}_')
        eval_files = [x for x in eval_files if listinstr([f'{d}_openai', f'{d}_gpt'], x) and x.endswith('.xlsx')]
        if len(eval_files) == 0:
            print(f'Model {model_name} x Dataset {d} openai missing')
            continue
        assert len(eval_files) == 1
        eval_file = eval_files[0]
        data = load(eval_file)
        if 'MMVet' in d:
            bad = [x for x in data['log'] if 'All 5 retries failed.' in str(x)]
            if len(bad):
                print(f'Model {model_name} x Dataset {d} Evaluation: {len(bad)} out of {len(data)} failed.')
        elif 'MathVista' in d:
            bad = [x for x in data['res'] if FAIL_MSG in str(x)]
            if len(bad):
                print(f'Model {model_name} x Dataset {d} Evaluation: {len(bad)} out of {len(data)} failed.')
        elif d == 'LLaVABench':
            sub = data[data['gpt4_score'] == -1]
            sub = sub[sub['gpt4_score'] == -1]
            if len(sub):
                print(f'Model {model_name} x Dataset {d} Evaluation: {len(sub)} out of {len(data)} failed.')
        else:
            bad = [x for x in data['log'] if FAIL_MSG in str(x)]
            if len(bad):
                print(f'Model {model_name} x Dataset {d} Evaluation: {len(bad)} out of {len(data)} failed.')
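# Usage sketch: pass the model output folder as argv[1] (layout assumed: a folder
# named after the model, containing `{model}_{dataset}.xlsx` predictions plus
# `*_openai` / `*_gpt` judge result files), e.g.:
#
#     python apires_scan.py outputs/GPT4V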
import argparse
from vlmeval.smp import *
from vlmeval.config import supported_VLM


def is_api(x):
    # supported_VLM maps model names to functools.partial objects; `.func` is the model class
    return getattr(supported_VLM[x].func, 'is_api', False)


models = list(supported_VLM)
models = [x for x in models if 'fs' not in x]
models = [x for x in models if not is_api(x)]
exclude_list = ['cogvlm-grounding-generalist', 'emu2']
models = [x for x in models if x not in exclude_list]


def is_large(x):
    return '80b' in x or 'emu2' in x or '34B' in x


# Run the small models first, then the large ones
small_models = [x for x in models if not is_large(x)]
large_models = [x for x in models if is_large(x)]
models = small_models + large_models

parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
args = parser.parse_args()

# Skip some models
models = [x for x in models if not listinstr(['MiniGPT', 'grounding-generalist'], x)]

for m in models:
    # Only launch datasets whose result file does not exist yet
    unknown_datasets = [x for x in args.data if not osp.exists(f'{m}/{m}_{x}.xlsx')]
    if len(unknown_datasets) == 0:
        continue
    dataset_str = ' '.join(unknown_datasets)
    if '80b' in m:
        cmd = f'python run.py --data {dataset_str} --model {m}'
    else:
        cmd = f'bash run.sh --data {dataset_str} --model {m}'
    print(cmd)
    os.system(cmd)
#!/bin/bash
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cp "$DIR/../config.py" "$DIR/../vlmeval/"
cp "$DIR"/../misc/* "$DIR/../vlmeval/vlm/misc/"
"""
pip install gradio # proxy_on first
python vis_geochat_data.py
# browse data in http://127.0.0.1:10064
"""
import os
import io
import json
import copy
import time
import gradio as gr
import base64
from PIL import Image
from io import BytesIO
from argparse import Namespace
# from llava import conversation as conversation_lib
from typing import Sequence
from vlmeval import *
from vlmeval.dataset import SUPPORTED_DATASETS, build_dataset
SYS = "You are a helpful assistant. Your job is to faithfully translate all provided text into Chinese faithfully. "
# Translator = SiliconFlowAPI(model='Qwen/Qwen2.5-7B-Instruct', system_prompt=SYS)
Translator = OpenAIWrapper(model='gpt-4o-mini', system_prompt=SYS)
def image_to_mdstring(image):
return f"![image](data:image/jpeg;base64,{image})"
def images_to_md(images):
return '\n\n'.join([image_to_mdstring(image) for image in images])
def mmqa_display(question, target_size=768):
question = {k.lower() if len(k) > 1 else k: v for k, v in question.items()}
keys = list(question.keys())
keys = [k for k in keys if k not in ['index', 'image']]
idx = question.pop('index', 'XXX')
text = f'\n- INDEX: {idx}\n'
images = question.pop('image')
if images[0] == '[' and images[-1] == ']':
images = eval(images)
else:
images = [images]
qtext = question.pop('question', None)
if qtext is not None:
text += f'- QUESTION: {qtext}\n'
if 'A' in question:
text += f'- Choices: \n'
for k in string.ascii_uppercase:
if k in question:
text += f'\t-{k}: {question.pop(k)}\n'
answer = question.pop('answer', None)
for k in question:
if not pd.isna(question[k]):
text += f'- {k.upper()}. {question[k]}\n'
if answer is not None:
text += f'- ANSWER: {answer}\n'
image_md = images_to_md(images)
return text, image_md
def parse_args():
parser = argparse.ArgumentParser()
# Essential Args, Setting the Names of Datasets and Models
parser.add_argument('--port', type=int, default=7860)
args = parser.parse_args()
return args
def gradio_app_vis_dataset(port=7860):
data, loaded_obj = None, {}
def btn_submit_click(filename, ann_id):
if filename not in loaded_obj:
return filename_change(filename, ann_id)
nonlocal data
data_desc = gr.Markdown(f'Visualizing {filename}, {len(data)} samples in total. ')
if ann_id < 0 or ann_id >= len(data):
return filename, ann_id, data_desc, gr.Markdown('Invalid Index'), gr.Markdown(f'Index out of range [0, {len(data) - 1}]')
item = data.iloc[ann_id]
text, image_md = mmqa_display(item)
return filename, ann_id, data_desc, image_md, text
def btn_next_click(filename, ann_id):
return btn_submit_click(filename, ann_id + 1)
# def translate_click(anno_en):
# return gr.Markdown(Translator.generate(anno_en))
def filename_change(filename, ann_id):
nonlocal data, loaded_obj
def legal_filename(filename):
LMURoot = LMUDataRoot()
if filename in SUPPORTED_DATASETS:
return build_dataset(filename).data
elif osp.exists(filename):
data = load(filename)
assert 'index' in data and 'image' in data
image_map = {i: image for i, image in zip(data['index'], data['image'])}
for k, v in image_map.items():
if (not isinstance(v, str) or len(v) < 64) and v in image_map:
image_map[k] = image_map[v]
data['image'] = [image_map[k] for k in data['index']]
return data
elif osp.exists(osp.join(LMURoot, filename)):
filename = osp.join(LMURoot, filename)
return legal_filename(filename)
else:
return None
data = legal_filename(filename)
if data is None:
return filename, 0, gr.Markdown(''), gr.Markdown("File not found"), gr.Markdown("File not found")
loaded_obj[filename] = data
return btn_submit_click(filename, 0)
with gr.Blocks() as app:
filename = gr.Textbox(
value='Dataset Name (supported by VLMEvalKit) or TSV FileName (Relative under `LMURoot` or Real Path)',
label='Dataset',
interactive=True,
visible=True)
with gr.Row():
ann_id = gr.Number(0, label='Sample Index (Press Enter)', interactive=True, visible=True)
btn_next = gr.Button("Next")
# btn_translate = gr.Button('CN Translate')
with gr.Row():
data_desc = gr.Markdown('Dataset Description', label='Dataset Description')
with gr.Row():
image_output = gr.Markdown('Image PlaceHolder', label='Image Visualization')
anno_en = gr.Markdown('Image Annotation', label='Image Annotation')
# anno_cn = gr.Markdown('Image Annotation (Chinese)', label='Image Annotation (Chinese)')
input_components = [filename, ann_id]
all_components = [filename, ann_id, data_desc, image_output, anno_en]
filename.submit(filename_change, input_components, all_components)
ann_id.submit(btn_submit_click, input_components, all_components)
btn_next.click(btn_next_click, input_components, all_components)
# btn_translate.click(translate_click, anno_en, anno_cn)
# app.launch()
app.launch(server_name='0.0.0.0', debug=True, show_error=True, server_port=port)
if __name__ == "__main__":
args = parse_args()
gradio_app_vis_dataset(port=args.port)
from vlmeval.smp import *
from vlmeval.tools import EVAL
import gradio as gr

HEADER = """
# Welcome to MMBench👏👏

We are delighted that you are willing to submit your evaluation results to the official MMBench website! The evaluation service currently handles submissions for MMBench, MMBench-CN, and CCBench. We use `gpt-3.5-turbo-0125` to help with answer matching. Evaluation code in VLMEvalKit: https://github.com/open-compass/VLMEvalKit. Please adopt / follow the implementation of VLMEvalKit to generate the submission files.

The evaluation script is available at https://github.com/open-compass/VLMEvalKit/tree/main/scripts/mmb_eval_gradio.py

Please contact `opencompass@pjlab.org.cn` with any inquiries about this script.
"""
def upload_file(file):
    file_path = file.name
    return file_path


def prepare_file(file_name):
    file_md5 = md5(file_name)
    root = LMUDataRoot()
    root = osp.join(root, 'eval_server')
    os.makedirs(root, exist_ok=True)
    suffix = file_name.split('.')[-1]
    if suffix not in ['xlsx', 'tsv', 'csv']:
        return False, 'Please submit a file that ends with `.xlsx`, `.tsv`, or `.csv`'
    new_file_name = osp.join(root, f'{file_md5}.{suffix}')
    shutil.move(file_name, new_file_name)
    eval_file = new_file_name
    try:
        data = load(eval_file)
    except Exception:
        return False, 'Your file can not be successfully loaded, please double check and submit again. '
    # Normalize column names: lower-case everything except the choice columns A-D
    for k in list(data.keys()):
        data[k.lower() if k not in 'ABCD' else k] = data.pop(k)
    if 'index' not in data:
        return False, 'Your file should have a column named `index`, please double check and submit again'
    if 'prediction' not in data:
        return False, 'Your file should have a column named `prediction`, please double check and submit again'
    for ch in 'ABCD':
        if ch not in data:
            return False, f'Your file should have a column named `{ch}`, please double check and submit again'
    dump(data, eval_file)
    return True, eval_file
def determine_dataset(eval_file):
    data = load(eval_file)

    def cn_ratio(data):
        iscn = [cn_string(x) for x in data['question']]
        return np.mean(iscn)

    max_ind = np.max([int(x) for x in data['index'] if int(x) < 1e5])
    if max_ind < 1000 and 'l2-category' not in data:
        return 'CCBench' if cn_ratio(data) > 0.5 else 'Unknown'
    elif max_ind < 3000:
        return 'MMBench_CN' if cn_ratio(data) > 0.5 else 'MMBench'
    else:
        return 'MMBench_CN_V11' if cn_ratio(data) > 0.5 else 'MMBench_V11'


def reformat_acc(acc):
    splits = set(acc['split'])
    keys = list(acc.keys())
    keys.remove('split')
    nacc = {'Category': []}
    for sp in splits:
        nacc[sp.upper()] = []
    for k in keys:
        nacc['Category'].append(k)
        for sp in splits:
            nacc[sp.upper()].append(acc[acc['split'] == sp].iloc[0][k] * 100)
    return pd.DataFrame(nacc)
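# Illustration (hypothetical numbers): reformat_acc() pivots a long-form table
#
#     split  Overall  coarse_perception
#     dev    0.75     0.80
#     test   0.71     0.77
#
# into one row per category, with upper-cased split columns scaled to percent:
#
#     Category           DEV   TEST
#     Overall            75.0  71.0
#     coarse_perception  80.0  77.0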
def evaluate(file):
    file_name = file.name
    flag, eval_file = prepare_file(file_name)
    if not flag:
        # eval_file holds the error message here; None leaves the result table unchanged
        return 'Error: ' + eval_file, None
    dataset = determine_dataset(eval_file)
    if dataset == 'Unknown':
        return 'Error: Cannot determine the dataset given your submitted file. ', None
    eval_id = eval_file.split('/')[-1].split('.')[0]
    ret = f'Evaluation ID: {eval_id}\n'
    timestamp = datetime.datetime.now().strftime('%Y.%m.%d %H:%M:%S')
    ret += f'Evaluation Timestamp: {timestamp}\n'
    acc = EVAL(dataset, eval_file)
    nacc = reformat_acc(acc).round(1)
    return ret, nacc


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    file_output = gr.File()
    upload_button = gr.UploadButton('Click to upload your prediction file for a supported benchmark')
    upload_button.upload(upload_file, upload_button, file_output)
    btn = gr.Button('🚀 Evaluate')
    eval_log = gr.Textbox(label='Evaluation Log', placeholder='Your evaluation log will be displayed here')
    df_empty = pd.DataFrame([], columns=['Evaluation Result'])
    eval_result = gr.components.DataFrame(value=df_empty)
    btn.click(evaluate, inputs=[file_output], outputs=[eval_log, eval_result])

if __name__ == '__main__':
    demo.launch(server_name='0.0.0.0', debug=True, show_error=True)
#!/bin/bash
set -x
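# Use every GPU visible to nvidia-smi; all CLI args are forwarded to run.py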
export GPU=$(nvidia-smi --list-gpus | wc -l)
torchrun --nproc-per-node=$GPU run.py ${@:1}
#!/bin/bash
set -x
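# $1 is the Slurm partition to use; the remaining args are forwarded to run.py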
srun -n1 --ntasks-per-node=1 --partition $1 --gres=gpu:8 --quotatype=reserved --job-name vlmeval --cpus-per-task=64 torchrun --nproc-per-node=8 run.py ${@:2}
from vlmeval.smp import *
from vlmeval.dataset import SUPPORTED_DATASETS


def get_score(model, dataset):
    file_name = f'{model}/{model}_{dataset}'
    if listinstr([
        'CCBench', 'MMBench', 'SEEDBench_IMG', 'MMMU', 'ScienceQA',
        'AI2D_TEST', 'MMStar', 'RealWorldQA', 'BLINK', 'VisOnlyQA-VLMEvalKit'
    ], dataset):
        file_name += '_acc.csv'
    elif listinstr(['MME', 'Hallusion', 'LLaVABench'], dataset):
        file_name += '_score.csv'
    elif listinstr(['MMVet', 'MathVista'], dataset):
        file_name += '_gpt-4-turbo_score.csv'
    elif listinstr(['COCO', 'OCRBench'], dataset):
        file_name += '_score.json'
    else:
        raise NotImplementedError
    if not osp.exists(file_name):
        return {}
    data = load(file_name)
    ret = {}
    if dataset == 'CCBench':
        ret[dataset] = data['Overall'][0] * 100
    elif dataset == 'MMBench':
        for n, a in zip(data['split'], data['Overall']):
            if n == 'dev':
                ret['MMBench_DEV_EN'] = a * 100
            elif n == 'test':
                ret['MMBench_TEST_EN'] = a * 100
    elif dataset == 'MMBench_CN':
        for n, a in zip(data['split'], data['Overall']):
            if n == 'dev':
                ret['MMBench_DEV_CN'] = a * 100
            elif n == 'test':
                ret['MMBench_TEST_CN'] = a * 100
    elif listinstr(['SEEDBench', 'ScienceQA', 'MMBench', 'AI2D_TEST', 'MMStar', 'RealWorldQA', 'BLINK'], dataset):
        ret[dataset] = data['Overall'][0] * 100
    elif 'MME' == dataset:
        ret[dataset] = data['perception'][0] + data['reasoning'][0]
    elif 'MMVet' == dataset:
        data = data[data['Category'] == 'Overall']
        ret[dataset] = float(data.iloc[0]['acc'])
    elif 'HallusionBench' == dataset:
        data = data[data['split'] == 'Overall']
        for met in ['aAcc', 'qAcc', 'fAcc']:
            ret[dataset + f' ({met})'] = float(data.iloc[0][met])
    elif 'MMMU' in dataset:
        data = data[data['split'] == 'validation']
        ret['MMMU (val)'] = float(data.iloc[0]['Overall']) * 100
    elif 'MathVista' in dataset:
        data = data[data['Task&Skill'] == 'Overall']
        ret[dataset] = float(data.iloc[0]['acc'])
    elif 'LLaVABench' in dataset:
        data = data[data['split'] == 'overall'].iloc[0]
        ret[dataset] = float(data['Relative Score (main)'])
    elif 'OCRBench' in dataset:
        ret[dataset] = data['Final Score']
    elif dataset == 'VisOnlyQA-VLMEvalKit':
        for n, a in zip(data['split'], data['Overall']):
            ret[f'VisOnlyQA-VLMEvalKit_{n}'] = a * 100
    return ret


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, nargs='+', default=[])
    parser.add_argument('--model', type=str, nargs='+', required=True)
    args = parser.parse_args()
    return args


def gen_table(models, datasets):
    res = defaultdict(dict)
    for m in models:
        for d in datasets:
            try:
                res[m].update(get_score(m, d))
            except Exception as e:
                logging.warning(f'{type(e)}: {e}')
                logging.warning(f'Missing Results for Model {m} x Dataset {d}')
    keys = []
    for m in models:
        for d in res[m]:
            keys.append(d)
    keys = list(set(keys))
    keys.sort()
    final = defaultdict(list)
    for m in models:
        final['Model'].append(m)
        for k in keys:
            if k in res[m]:
                final[k].append(res[m][k])
            else:
                final[k].append(None)
    final = pd.DataFrame(final)
    dump(final, 'summ.csv')
    if len(final) >= len(final.iloc[0].keys()):
        print(tabulate(final))
    else:
        print(tabulate(final.T))


if __name__ == '__main__':
    args = parse_args()
    if args.data == []:
        args.data = list(SUPPORTED_DATASETS)
    gen_table(args.model, args.data)
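# Usage sketch (result layout assumed: `{model}/{model}_{dataset}_*.csv|json`):
#
#     python summarize.py --model GPT4V --data MMBench MME
#
# If --data is omitted, all SUPPORTED_DATASETS are summarized; the table is
# printed with tabulate and also dumped to summ.csv.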
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import copy as cp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.font_manager as fm\n",
"\n",
"def download_file(url, filename=None):\n",
" from urllib.request import urlretrieve\n",
" if filename is None:\n",
" filename = url.split('/')[-1]\n",
" urlretrieve(url, filename)\n",
"\n",
"font_URL = 'http://opencompass.openxlab.space/utils/Fonts/segoepr.ttf'\n",
"download_file(font_URL)\n",
"\n",
"font12 = fm.FontProperties(fname='segoepr.ttf', size=12)\n",
"font15 = fm.FontProperties(fname='segoepr.ttf', size=15, weight='bold')\n",
"font18 = fm.FontProperties(fname='segoepr.ttf', size=18, weight='bold')\n",
"\n",
"DATA_URL = 'http://opencompass.openxlab.space/utils/OpenVLM.json'\n",
"download_file(DATA_URL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def pre_normalize(raw_data, labels):\n",
" data_list = cp.deepcopy(raw_data)\n",
" minimum, maximum, max_range, range_map = {}, {}, 0, {}\n",
" for lb in labels:\n",
" minimum[lb] = min([x[lb] for x in data_list])\n",
" maximum[lb] = max([x[lb] for x in data_list])\n",
" max_range = max(max_range, maximum[lb] - minimum[lb])\n",
" max_range *= 1.25\n",
" for lb in labels:\n",
" mid = (minimum[lb] + maximum[lb]) / 2\n",
" new_range = (mid - max_range / 2, mid + max_range / 2) if (mid + max_range / 2) < 100 else (100 - max_range, 100)\n",
" range_map[lb] = new_range\n",
" for item in data_list:\n",
" assert new_range[0] <= item[lb] <= new_range[1]\n",
" item[lb] = (item[lb] - new_range[0]) / max_range * 100\n",
" return data_list, range_map\n",
"\n",
"# solve the problem that some benchmark score is too high and out of range\n",
"def log_normalize(raw_data, labels):\n",
" data_list = cp.deepcopy(raw_data)\n",
" minimum, maximum, max_range, range_map = {}, {}, 0, {}\n",
" for lb in labels:\n",
" minimum[lb] = min([np.log(x[lb]) for x in data_list])\n",
" maximum[lb] = max([np.log(x[lb]) for x in data_list])\n",
" max_range = max(max_range, maximum[lb] - minimum[lb])\n",
" max_range *= 1.005\n",
" for lb in labels:\n",
" mid = (minimum[lb] + maximum[lb]) / 2\n",
" new_range = (mid - max_range / 2, mid + max_range / 2) if (mid + max_range / 2) < 100 else (100 - max_range, 100)\n",
" range_map[lb] = new_range\n",
" for item in data_list:\n",
" assert new_range[0] <= np.log(item[lb]) <= new_range[1]\n",
" item[lb] = (np.log(item[lb]) - new_range[0]) / max_range * 100\n",
" return data_list, range_map"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw MMBench Radar Graph\n",
"data = json.loads(open('OpenVLM.json').read())['results']\n",
"models = list(data)\n",
"print(models)\n",
"\n",
"# model2vis = [\n",
"# 'GPT-4v (detail: low)', 'GeminiProVision', 'Qwen-VL-Plus', \n",
"# 'InternLM-XComposer2-VL', 'LLaVA-v1.5-13B', 'CogVLM-17B-Chat',\n",
"# 'mPLUG-Owl2', 'Qwen-VL-Chat', 'IDEFICS-80B-Instruct'\n",
"# ]\n",
"\n",
"model2vis = [\n",
" # 'GPT-4v (detail: low)', 'GeminiProVision', 'InternLM-XComposer2-VL', \n",
" 'GPT-4v (1106, detail-low)', 'Gemini-1.0-Pro', 'Gemini-1.5-Pro', #'Gemini-1.5-Flash', 'Qwen-VL-Plus', \n",
" 'InternLM-XComposer2', 'LLaVA-v1.5-13B', 'CogVLM-17B-Chat',\n",
" 'mPLUG-Owl2', 'Qwen-VL-Chat', 'IDEFICS-80B-Instruct'\n",
"]\n",
"\n",
"colors = [\n",
" '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', \n",
" '#e377c2', '#7f7f7f', '#bcbd22'\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"split = 'MMBench_TEST_EN'\n",
"# data_sub = {k: v[split] for k, v in data.items()}\n",
"data_sub = {k: defaultdict(int, v)[split] for k, v in data.items()}\n",
"# solve the problem that some model lack the evaluation of MMBench_TEST_EN\n",
"\n",
"labels = list(data_sub[model2vis[0]])\n",
"labels.remove('Overall')\n",
"num_vars = len(labels)\n",
"\n",
"raw_data = [data_sub[m] for m in model2vis]\n",
"data_list, range_map = pre_normalize(raw_data, labels)\n",
"\n",
"alpha = 0.25\n",
"angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()\n",
"angles_deg = np.linspace(0, 360, num_vars, endpoint=False).tolist()\n",
"fig, ax_base = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), subplot_kw=dict(polar=True))\n",
"\n",
"for i in range(len(data_list)):\n",
" item = data_list[i]\n",
" model_name = model2vis[i]\n",
" color = colors[i]\n",
" tmp_angles = angles[:] + [angles[0]]\n",
" tmp_values = [item[lb] for lb in labels] + [item[labels[0]]]\n",
" ax_base.plot(tmp_angles, tmp_values, color=color, linewidth=1, linestyle='solid', label=model_name)\n",
" ax_base.fill(tmp_angles, tmp_values, color=color, alpha=alpha)\n",
" \n",
"angles += [angles[0]]\n",
"ax_base.set_ylim(0, 100)\n",
"ax_base.set_yticks([40, 60, 80, 100])\n",
"ax_base.set_yticklabels([''] * 4)\n",
"\n",
"ax_base.tick_params(pad=25)\n",
"ax_base.set_xticks(angles[:-1])\n",
"ax_base.set_xticklabels(labels, fontproperties=font18)\n",
"\n",
"leg = ax_base.legend(loc='center right', bbox_to_anchor=(1.6, 0.5), prop=font15, ncol=1, frameon=True, labelspacing=1.2)\n",
"for line in leg.get_lines():\n",
" line.set_linewidth(2.5)\n",
"\n",
"cx, cy, sz = 0.44, 0.435, 0.34\n",
"axes = [fig.add_axes([cx - sz, cy - sz, cx + sz, cy + sz], projection='polar', label='axes%d' % i) for i in range(num_vars)]\n",
" \n",
"for ax, angle, label in zip(axes, angles_deg, labels):\n",
" ax.patch.set_visible(False)\n",
" ax.grid(False)\n",
" ax.xaxis.set_visible(False)\n",
" cur_range = range_map[label]\n",
" label_list = [cur_range[0] + (cur_range[1] - cur_range[0]) / 5 * i for i in range(2, 6)]\n",
" label_list = [f'{x:.1f}' for x in label_list]\n",
" ax.set_rgrids(range(40, 120, 20), angle=angle, labels=label_list, font_properties=font12)\n",
" ax.spines['polar'].set_visible(False)\n",
" ax.set_ylim(0, 100)\n",
"\n",
"title_text = f'{len(model2vis)} Representative VLMs on MMBench Test.'\n",
"plt.figtext(.7, .95, title_text, fontproperties=font18, ha='center')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"labels = ['SEEDBench_IMG', 'CCBench', 'MMBench_TEST_EN', 'MMBench_TEST_CN', 'MME', 'MMVet', 'MMMU_VAL', 'MathVista', 'HallusionBench', 'LLaVABench']\n",
"num_vars = len(labels)\n",
"\n",
"raw_data = [{k: data[m][k]['Overall'] for k in labels} for m in model2vis]\n",
"data_list, range_map = pre_normalize(raw_data, labels)\n",
"\n",
"alpha = 0.25\n",
"angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()\n",
"angles_deg = np.linspace(0, 360, num_vars, endpoint=False).tolist()\n",
"fig, ax_base = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), subplot_kw=dict(polar=True))\n",
"\n",
"for i in range(len(data_list)):\n",
" item = data_list[i]\n",
" model_name = model2vis[i]\n",
" color = colors[i]\n",
" tmp_angles = angles[:] + [angles[0]]\n",
" tmp_values = [item[lb] for lb in labels] + [item[labels[0]]]\n",
" ax_base.plot(tmp_angles, tmp_values, color=color, linewidth=1, linestyle='solid', label=model_name)\n",
" ax_base.fill(tmp_angles, tmp_values, color=color, alpha=alpha)\n",
" \n",
"angles += [angles[0]]\n",
"ax_base.set_ylim(0, 100)\n",
"ax_base.set_yticks([40, 60, 80, 100])\n",
"ax_base.set_yticklabels([''] * 4)\n",
"\n",
"ax_base.tick_params(pad=15)\n",
"ax_base.set_xticks(angles[:-1])\n",
"ax_base.set_xticklabels(labels, fontproperties=font18)\n",
"\n",
"dataset_map = {\n",
" 'MMBench_TEST_EN': 'MMBench (Test)', \n",
" 'MMBench_TEST_CN': 'MMBenchCN (Test)', \n",
" 'MathVista': 'MathVista (TestMini)', \n",
" 'MMMU_VAL': 'MMMU (Val)'\n",
"}\n",
"for i, label in enumerate(ax_base.get_xticklabels()):\n",
" x,y = label.get_position()\n",
" text = label.get_text()\n",
" text = dataset_map[text] if text in dataset_map else text\n",
" lab = ax_base.text(x, y, text, transform=label.get_transform(),\n",
" ha=label.get_ha(), va=label.get_va(), font_properties=font15)\n",
" lab.set_rotation(360 / num_vars * i + 270)\n",
" labels.append(lab)\n",
"ax_base.set_xticklabels([])\n",
"\n",
"leg = ax_base.legend(loc='center right', bbox_to_anchor=(1.6, 0.5), prop=font15, ncol=1, frameon=True, labelspacing=1.2)\n",
"for line in leg.get_lines():\n",
" line.set_linewidth(2.5)\n",
"\n",
"cx, cy, sz = 0.44, 0.435, 0.34\n",
"axes = [fig.add_axes([cx - sz, cy - sz, cx + sz, cy + sz], projection='polar', label='axes%d' % i) for i in range(num_vars)]\n",
" \n",
"for ax, angle, label in zip(axes, angles_deg, labels):\n",
" ax.patch.set_visible(False)\n",
" ax.grid(False)\n",
" ax.xaxis.set_visible(False)\n",
" cur_range = range_map[label]\n",
" label_list = [cur_range[0] + (cur_range[1] - cur_range[0]) / 5 * i for i in range(2, 6)]\n",
" label_list = [f'{x:.1f}' for x in label_list]\n",
" ax.set_rgrids(range(40, 120, 20), angle=angle, labels=label_list, font_properties=font12)\n",
" ax.spines['polar'].set_visible(False)\n",
" ax.set_ylim(0, 100)\n",
"\n",
"title_text = f'{len(model2vis)} Representative VLMs on {num_vars} Benchmarks in OpenCompass Multi-Modal Leaderboard.'\n",
"plt.figtext(.7, .95, title_text, fontproperties=font18, ha='center')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import re
import sys
from os.path import exists
from setuptools import find_packages, setup
def parse_requirements(fname='requirements.txt', with_version=True):
    """Parse the package dependencies listed in a requirements file but strip
    specific versioning information.

    Args:
        fname (str): path to requirements file
        with_version (bool, default=False): if True include version specs

    Returns:
        List[str]: list of requirements items

    CommandLine:
        python -c "import setup; print(setup.parse_requirements())"
    """
    require_fpath = fname

    def parse_line(line):
        """Parse information from a line in a requirements text file."""
        if line.startswith('-r '):
            # Allow specifying requirements in other files
            target = line.split(' ')[1]
            for info in parse_require_file(target):
                yield info
        else:
            info = {'line': line}
            if line.startswith('-e '):
                info['package'] = line.split('#egg=')[1]
            elif '@git+' in line:
                info['package'] = line
            else:
                # Remove versioning from the package
                pat = '(' + '|'.join(['>=', '==', '>']) + ')'
                parts = re.split(pat, line, maxsplit=1)
                parts = [p.strip() for p in parts]
                info['package'] = parts[0]
                if len(parts) > 1:
                    op, rest = parts[1:]
                    if ';' in rest:
                        # Handle platform specific dependencies
                        # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
                        version, platform_deps = map(str.strip, rest.split(';'))
                        info['platform_deps'] = platform_deps
                    else:
                        version = rest  # NOQA
                    info['version'] = (op, version)
            yield info

    def parse_require_file(fpath):
        with open(fpath, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if line and not line.startswith('#'):
                    for info in parse_line(line):
                        yield info

    def gen_packages_items():
        if exists(require_fpath):
            for info in parse_require_file(require_fpath):
                parts = [info['package']]
                if with_version and 'version' in info:
                    parts.extend(info['version'])
                if not sys.version.startswith('3.4'):
                    # apparently package_deps are broken in 3.4
                    platform_deps = info.get('platform_deps')
                    if platform_deps is not None:
                        parts.append(';' + platform_deps)
                item = ''.join(parts)
                yield item

    packages = list(gen_packages_items())
    return packages
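# For illustration (hypothetical requirement line):
#     "opencv-python>=4.4.0.46; platform_machine != 'arm64'"
# becomes "opencv-python>=4.4.0.46;platform_machine != 'arm64'" with
# with_version=True, and "opencv-python;platform_machine != 'arm64'" with
# with_version=False.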
with open('README.md') as f:
    readme = f.read()


def do_setup():
    setup(
        name='vlmeval',
        version='0.1.0',
        description='OpenCompass VLM Evaluation Kit',
        author='Haodong Duan',
        author_email='dhd.efz@gmail.com',
        maintainer='Haodong Duan',
        maintainer_email='dhd.efz@gmail.com',
        long_description=readme,
        long_description_content_type='text/markdown',
        cmdclass={},
        install_requires=parse_requirements('requirements.txt'),
        setup_requires=[],
        python_requires='>=3.7.0',
        packages=find_packages(exclude=[
            'test*',
            'paper_test*',
        ]),
        keywords=['AI', 'NLP', 'in-context learning'],
        entry_points={
            'console_scripts': ['vlmutil = vlmeval:cli']
        },
        classifiers=[
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Programming Language :: Python :: 3.9',
            'Programming Language :: Python :: 3.10',
            'Intended Audience :: Developers',
            'Intended Audience :: Education',
            'Intended Audience :: Science/Research',
        ])


if __name__ == '__main__':
    do_setup()