Unverified Commit fb963f0f authored by Lintang Sutawika, committed by GitHub

Multimodal prototyping (#2243)



* add WIP hf vlm class

* add doc_to_image

* add mmmu tasks

* fix merge conflicts

* add lintang's changes to hf_vlms.py

* fix doc_to_image

* added yaml_path for config-loading

* revert

* add line to process str type v

* update

* modeling cleanup

* add aggregation for mmmu

* rewrite MMMU processing code based on only MMMU authors' repo (doc_to_image still WIP)

* implemented doc_to_image

* update doc_to_image to accept list of features

* update functions

* readd image processed

* update args process

* bugfix for repeated images fed to model

* push WIP loglikelihood code

* commit most recent code (generative ; qwen2-vl testing)

* preliminary image_token_id handling

* small mmmu update: some qs have >4 mcqa options

* push updated modeling code

* use processor.apply_chat_template

* add mathvista draft

* nit

* nit

* ensure no footguns in text<>multimodal LM<>task incompatibility

* add notification to readme regarding launch of prototype!

* fix compatibility check

* reorganize mmmu configs

* chat_template=None

* add interleave chat_template

* add condition

* add max_images; interleave=true

* nit

* testmini_mcq

* nit

* pass image string; convert img

* add vllm

* add init

* vlm add multi attr

* fixup

* pass max images to vllm model init

* nit

* encoding to device

* fix HFMultimodalLM.chat_template ?

* add mmmu readme

* remove erroneous prints

* use HFMultimodalLM.chat_template ; restore tasks/__init__.py

* add docstring for replace_placeholders in utils

* fix `replace_placeholders`; set image_string=None

* fix typo

* cleanup + fix merge conflicts

* update MMMU readme

* del mathvista

* add some sample scores

* Update README.md

* add log msg for image_string value

---------
Co-authored-by: haileyschoelkopf <hailey@eleuther.ai>
Co-authored-by: Baber Abbasi <baber@eleuther.ai>
Co-authored-by: Baber <baber@hey.com>
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent decc533d
@@ -6,6 +6,7 @@
*Latest News 📣*
- [2024/09] We are prototyping support for creating and evaluating tasks with text+image multimodal input and text output, and have added the `hf-multimodal` and `vllm-vlm` model types and the `mmmu` task as a prototype feature (see the example invocation below). We welcome users to try out and stress-test this in-progress feature, and we recommend [`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval), a wonderful project that originally forked off the lm-evaluation-harness, for a broader range of multimodal tasks, models, and features.
- [2024/07] [API model](docs/API_guide.md) support has been updated and refactored, introducing support for batched and async requests, and making it significantly easier to customize and use for your own purposes. **To run Llama 405B, we recommend using VLLM's OpenAI-compliant API to host the model, and use the `local-completions` model type to evaluate the model.**
- [2024/07] New Open LLM Leaderboard tasks have been added! You can find them under the [leaderboard](lm_eval/tasks/leaderboard/README.md) task group.
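An illustrative way to try the multimodal prototype from Python. This is a minimal sketch only: the model name and `model_args` values are examples, not a recommended configuration, and the multimodal interface may change while the feature is in prototype.

```python
# Minimal sketch, assuming the harness's existing Python entry point `simple_evaluate`.
# Model choice and model_args below are illustrative.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf-multimodal",  # or "vllm-vlm"
    model_args="pretrained=Qwen/Qwen2-VL-2B-Instruct,dtype=bfloat16",
    tasks=["mmmu_val"],
    batch_size=2,
    apply_chat_template=True,  # multimodal prompts are built through the processor's chat template
)
print(results["results"])
```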
@@ -75,6 +75,7 @@ class TaskConfig(dict):
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Optional[Union[Callable, str]] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
@@ -377,6 +378,10 @@ class Task(abc.ABC):
def doc_to_target(self, doc):
pass
# not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc):
raise NotImplementedError
def build_all_requests(
self,
*,
@@ -735,6 +740,10 @@ class ConfigurableTask(Task):
)
self.OUTPUT_TYPE = self.config.output_type
if self.config.doc_to_image is not None:
# mark the task as requiring multimodality.
self.MULTIMODAL = True
if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path
@@ -1042,8 +1051,8 @@ class ConfigurableTask(Task):
Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param chat_template: Callable
Chat template to be applied to the fewshot context.
:param chat_template:
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
:returns: str
The fewshot context.
"""
@@ -1279,9 +1288,34 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
doc_to_image = self.config.doc_to_image
else:
return None
if isinstance(doc_to_image, list):
image_feature = [
self.doc_to_image(doc, feature) for feature in doc_to_image
]
return [feature for feature in image_feature if feature is not None]
elif isinstance(doc_to_image, str):
if doc_to_image in self.features:
return doc[doc_to_image]
else:
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
elif callable(doc_to_image):
return doc_to_image(doc)
else:
return None
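# Illustrative resolution of the `doc_to_image` spec (hypothetical doc; not part of the diff):
#   doc = {"image_1": <PIL.Image>, "image_2": <PIL.Image>}
#   doc_to_image="image_1"                 -> doc["image_1"]
#   doc_to_image=["image_1", "image_2"]    -> [doc["image_1"], doc["image_2"]]
#   doc_to_image=<callable>                -> whatever the callable returns for `doc`
#   any other string                       -> rendered as a template against `doc`, then ast.literal_eval'd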
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
aux_arguments = None
if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
@@ -1299,6 +1333,37 @@ class ConfigurableTask(Task):
# Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
aux_arguments = [("", f"{choice}") for choice in choices]
arguments.extend(aux_arguments)
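# Worked example (illustrative; not part of the diff): with choices ["A", "B"] and the
# usual " " target delimiter, the request arguments become
#   [(ctx, " A"), (ctx, " B"),   # conditional log-probs  logP(choice | ctx)
#    ("", "A"),   ("", "B")]     # unconditional log-probs logP(choice)
# and acc_mutual_info later scores argmax_i [ logP(choice_i | ctx) - logP(choice_i) ].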
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
multimodal_arg = {}
if (
self.config.doc_to_image
): # TODO: ensure that non-multimodal tasks aren't getting visual args
multimodal_arg = {
**multimodal_arg,
**{"visual": self.doc_to_image(doc)},
}
if bool(multimodal_arg):
if isinstance(arguments, list):
arguments = [arg + (multimodal_arg,) for arg in arguments]
else:
arguments = arguments + (multimodal_arg,)
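# Shape check (illustrative; not part of the diff): for a multimodal generate_until doc,
# the per-request arguments become a 3-tuple such as
#   (ctx, {"until": ["<|endoftext|>"], "max_gen_toks": 512}, {"visual": [<PIL.Image>, ...]})
# (for loglikelihood-style tasks, each (ctx, continuation) pair is extended the same way).
# Model classes unpack this trailing dict; see `zip(*chunk)` and `arg["visual"]` in the
# vllm-vlm `generate_until` further below.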
if self.OUTPUT_TYPE == "multiple_choice":
request_list = [
Instance(
request_type="loglikelihood",
@@ -1309,33 +1374,15 @@ class ConfigurableTask(Task):
)
for i, arg in enumerate(arguments)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend(
[
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
]
)
return request_list
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
idx=0,
**kwargs,
)
def process_results(self, doc, results):
@@ -1547,7 +1594,7 @@ class ConfigurableTask(Task):
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})"
f"num_samples={len(self.eval_docs)})",
)
@@ -414,8 +414,28 @@ def evaluate(
for task_output in eval_tasks
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
# validation check: are we running multimodal task <-> non-multimodal model class, or vice-versa.
incompatible_tasks = []
for task_output in eval_tasks:
task: Task = task_output.task
if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
incompatible_tasks.append(task_output.task_name)
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
)
else:
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end multimodality validation check
for task_output in eval_tasks:
task: Task = task_output.task
limit = get_sample_size(task, limit)
task.build_all_requests(
limit=limit,
@@ -3,6 +3,7 @@ from . import (
api_models,
dummy,
gguf,
hf_vlms,
huggingface,
mamba_lm,
nemo_lm,
@@ -12,6 +13,7 @@ from . import (
optimum_lm,
textsynth,
vllm_causallms,
vllm_vlms,
)
@@ -448,7 +448,16 @@ class HFLM(TemplateLM):
Helper method during initialization.
Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
model type to be used.
Sets `self.AUTO_MODEL_CLASS` appropriately if not already set.
"""
# escape hatch: if we're using a subclass that shouldn't follow
# the default _get_backend logic,
# then skip over the method.
# TODO: this seems very much undesirable in some cases--our code in HFLM
# references AutoModelForCausalLM at times to check for equality
if self.AUTO_MODEL_CLASS is not None:
return
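# Illustrative (assumption; the new multimodal HF subclass is not shown in this view):
# such a subclass can pre-set AUTO_MODEL_CLASS at the class level, e.g. to
# transformers.AutoModelForVision2Seq, so that this early return preserves that choice
# instead of re-running the causal/seq2seq selection below.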
assert backend in ["default", "causal", "seq2seq"]
if backend != "default":
@@ -664,3 +664,37 @@ def configure_pad_token(
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
return tokenizer
def replace_placeholders(
string: str, default_placeholder: str, image_token: str, max_images: int
):
"""
A utility function used for local multimodal models. It locates all occurrences of
`default_placeholder` in the input `string` and replaces the first `max_images` of them with
`image_token`; any later occurrences are left in place as `default_placeholder` (or dropped
entirely when the placeholder and the image token are identical).
This is used to replace <image> placeholder tags with model-specific image tokens like <|image_pad|>
and to ensure that at most the first `max_images` images are passed to a model if desired.
:param string: The original string containing placeholders.
:param default_placeholder: The placeholder text to be replaced.
:param image_token: The token to replace the placeholder with.
:param max_images: The maximum number of replacements to make.
:return: The string with placeholders replaced.
"""
count = 0
result = []
parts = string.split(default_placeholder)
for part in parts[:-1]: # Iterate through all but the last part
result.append(part)
if count < max_images:
result.append(image_token)
count += 1
elif default_placeholder != image_token:
result.append(default_placeholder)
# Add the last part of the string
result.append(parts[-1])
return "".join(result)
import copy
from typing import Dict, List, Optional
import transformers
from more_itertools import distribute
from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, undistribute
from lm_eval.models.vllm_causallms import VLLM
from lm_eval.utils import simple_parse_args_string
try:
import ray
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest # noqa: F401
from vllm.transformers_utils.tokenizer import get_tokenizer # noqa: F401
except ModuleNotFoundError:
pass
DEFAULT_IMAGE_PLACEHOLDER = "<image>"
@register_model("vllm-vlm")
class VLLM_VLM(VLLM):
MULTIMODAL = True
def __init__(
self,
pretrained: str,
trust_remote_code: Optional[bool] = False,
revision: Optional[str] = None,
interleave: bool = True,
# TODO<baber>: handle max_images and limit_mm_per_prompt better
max_images: int = 999,
limit_mm_per_prompt: str = "image=1",
**kwargs,
):
kwargs["limit_mm_per_prompt"] = simple_parse_args_string(limit_mm_per_prompt)
super().__init__(
pretrained=pretrained,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
self.interleave = interleave
self.max_images = max_images
self.processor = transformers.AutoProcessor.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
)
self.chat_applied: bool = False
def tok_batch_multimodal_encode(
self,
strings: List[str], # note that input signature of this fn is different
images, # TODO: typehint on this
left_truncate_len: int = None,
truncation: bool = False,
):
images = [img[: self.max_images] for img in images]
outputs = []
for x, i in zip(strings, images):
inputs = {
"prompt": x,
"multi_modal_data": {"image": i},
}
outputs.append(inputs)
return outputs
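# Example of the structure handed to vLLM (values illustrative; not part of the diff):
#   [{"prompt": "...<|image_pad|>...", "multi_modal_data": {"image": [<PIL.Image>, ...]}}]
# i.e. one dict per request, with at most `max_images` images kept per prompt.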
def _model_generate(
self,
requests: List[List[dict]] = None,
generate: bool = False,
max_tokens: int = None,
stop: Optional[List[str]] = None,
**kwargs,
):
if generate:
kwargs = self.modify_gen_kwargs(kwargs)
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else:
sampling_params = SamplingParams(
temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
)
if self.data_parallel_size > 1:
# vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
# also seems to only work with decorator and not with ray.remote() fn
# see https://github.com/vllm-project/vllm/issues/973
# note: this has changed on 0.3.3, and it only works now if num_gpus are set.
# but then tensor_parallel breaks
@ray.remote
def run_inference_one_model(
model_args: dict, sampling_params, requests: List[List[dict]]
):
llm = LLM(**model_args)
return llm.generate(requests, sampling_params=sampling_params)
# dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers
requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
inputs = ((self.model_args, sampling_params, req) for req in requests)
object_refs = [run_inference_one_model.remote(*x) for x in inputs]
results = ray.get(object_refs)
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
ray.shutdown()
# flatten results
return undistribute(results)
if self.lora_request is not None:
outputs = self.model.generate(
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
lora_request=self.lora_request,
)
else:
outputs = self.model.generate(
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
)
return outputs
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
self.chat_applied = True
if not self.interleave:
for content in chat_history:
c = []
text = content["content"]
# Count and remove image placeholders
image_count = min(
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
)
text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")
# Add image entries
for _ in range(image_count):
c.append({"type": "image", "image": None})
# Add single text entry at the end
c.append({"type": "text", "text": text})
content["content"] = c
else:
for content in chat_history:
c = []
text = content["content"]
expected_image_count = min(
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
)
actual_image_count = 0
text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)
for i, part in enumerate(text_parts):
# TODO: concatenate text parts (esp. if skipping images)?
if part: # Add non-empty text parts
c.append({"type": "text", "text": part})
if (
(i < len(text_parts) - 1) and i < self.max_images
): # Add image placeholder after each split except the last
c.append({"type": "image"})
actual_image_count += 1
content["content"] = c
if actual_image_count != expected_image_count:
raise ValueError(
f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
)
return self.processor.apply_chat_template(
chat_history, add_generation_prompt=True
)
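# Illustrative before/after for the interleave=True branch (values hypothetical):
#   content["content"] = "<image> What is shown in the figure?"
# becomes
#   content["content"] = [{"type": "image"},
#                         {"type": "text", "text": " What is shown in the figure?"}]
# which the processor's chat template then renders into the model's prompt format.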
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
# TODO: support text-only reqs
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return -len(toks), x[0]
pbar = tqdm(
total=len(requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests with text+image input",
)
# TODO: port auto-batch sizing into this.
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = Collator(
[reg.args for reg in requests],
_collate,
group_by="gen_kwargs",
group_fn=lambda x: x[1],
)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
for chunk in chunks:
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
visuals = [arg["visual"] for arg in aux_arguments]
if not isinstance(contexts, list):
contexts = list(
contexts
) # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list.
# TODO: could we upstream this workaround to HF?
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tokenizer.decode(self.eot_token_id)
if not until:
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
max_ctx_len = self.max_length - max_gen_toks
inputs = self.tok_batch_multimodal_encode(
contexts,
visuals,
left_truncate_len=max_ctx_len,
)
cont = self._model_generate(inputs, stop=until, generate=True, **kwargs)
for output, context in zip(cont, contexts):
generated_text = output.outputs[0].text
res.append(generated_text)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)
pbar.close()
return res
# MMMU Benchmark
### Paper
Title: `MMMU: A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI`
Abstract: `MMMU is a new benchmark designed to evaluate multimodal models on massive multi-discipline tasks demanding college-level subject knowledge and deliberate reasoning.`
`The benchmark is composed of 30 tasks, for a total of 900 mixed image+text examples (some with multiple images in context)`
Homepage: `https://github.com/MMMU-Benchmark/MMMU/tree/main/mmmu`
Note: Some questions have multiple images in context. To control for this, set `max_images=N` at model init.
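For example, a minimal sketch capping each prompt at two images; the model and remaining arguments loosely mirror the Idefics2 configuration reported under Variants below and should be treated as illustrative:

```python
# Sketch only: `max_images` is passed through model_args to the multimodal model class.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf-multimodal",
    model_args="pretrained=HuggingFaceM4/idefics2-8b,dtype=bfloat16,max_images=2",
    tasks=["mmmu_val"],
    batch_size=2,
    apply_chat_template=True,
)
```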
### Citation
```
@inproceedings{yue2023mmmu,
title={MMMU: A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI},
author={Xiang Yue and Yuansheng Ni and Kai Zhang and Tianyu Zheng and Ruoqi Liu and Ge Zhang and Samuel Stevens and Dongfu Jiang and Weiming Ren and Yuxuan Sun and Cong Wei and Botao Yu and Ruibin Yuan and Renliang Sun and Ming Yin and Boyuan Zheng and Zhenzhu Yang and Yibo Liu and Wenhao Huang and Huan Sun and Yu Su and Wenhu Chen},
booktitle={Proceedings of CVPR},
year={2024},
}
```
### Groups, Tags, and Tasks
#### Groups
* `mmmu_val`
* `mmmu_val_art_and_design`
* `mmmu_val_business`
* `mmmu_val_health_and_medicine`
* `mmmu_val_humanities_and_social_science`
* `mmmu_val_science`
* `mmmu_val_tech_and_engineering`
#### Tags
#### Tasks
* `mmmu_val_accounting`
* `mmmu_val_agriculture`
* `mmmu_val_architecture_and_engineering`
* `mmmu_val_art`
* `mmmu_val_art_theory`
* `mmmu_val_basic_medical_science`
* `mmmu_val_biology`
* `mmmu_val_chemistry`
* `mmmu_val_computer_science`
* `mmmu_val_clinical_medicine`
* `mmmu_val_design`
* `mmmu_val_diagnostics_and_laboratory_medicine`
* `mmmu_val_electronics`
* `mmmu_val_energy_and_power`
* `mmmu_val_economics`
* `mmmu_val_finance`
* `mmmu_val_geography`
* `mmmu_val_history`
* ...
### Variants
The `mmmu_val` group implements MMMU using processing code [from the original MMMU authors](https://github.com/MMMU-Benchmark/MMMU/tree/main/mmmu) and uses the prompt format found in [the MMMU repository for Llava-1.5](https://github.com/MMMU-Benchmark/MMMU/blob/main/mmmu/configs/llava1.5.yaml). This implementation should give scores on par with or slightly higher than those reported by [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/main/lmms_eval/tasks/mmmu) for `mmmu_val` and the MMMU repository code.
Scores on several tested models (**all with `--apply_chat_template`**) are:
Qwen2-VL-2B:
```
hf-multimodal (pretrained=Qwen/Qwen2-VL-2B-Instruct,attn_implementation=flash_attention_2,dtype=bfloat16,convert_img_format=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 2
```
```
| Groups |Version|Filter|n-shot|Metric| |Value | |Stderr|
|--------------------------------|------:|------|------|------|---|-----:|---|-----:|
|mmmu_val | 0|none | |acc |↑ |0.3778|± |0.0155|
| - Art and Design | 0|none | |acc |↑ |0.5500|± |0.0415|
| - Business | 0|none | |acc |↑ |0.3600|± |0.0389|
| - Health and Medicine | 0|none | |acc |↑ |0.3667|± |0.0394|
| - Humanities and Social Science| 0|none | |acc |↑ |0.5167|± |0.0438|
| - Science | 0|none | |acc |↑ |0.2467|± |0.0352|
| - Tech and Engineering | 0|none | |acc |↑ |0.3143|± |0.0317|
```
Author-reported score: 41.1%
Qwen2-VL-7B:
```
hf-multimodal (pretrained=Qwen/Qwen2-VL-7B-Instruct,attn_implementation=flash_attention_2,dtype=bfloat16,convert_img_format=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 2
```
```
| Groups |Version|Filter|n-shot|Metric| |Value | |Stderr|
|--------------------------------|------:|------|------|------|---|-----:|---|-----:|
|mmmu_val | 0|none | |acc |↑ |0.5056|± |0.0160|
| - Art and Design | 0|none | |acc |↑ |0.6917|± |0.0398|
| - Business | 0|none | |acc |↑ |0.4333|± |0.0406|
| - Health and Medicine | 0|none | |acc |↑ |0.5667|± |0.0401|
| - Humanities and Social Science| 0|none | |acc |↑ |0.6750|± |0.0426|
| - Science | 0|none | |acc |↑ |0.3800|± |0.0392|
| - Tech and Engineering | 0|none | |acc |↑ |0.4000|± |0.0341|
```
Author-reported score: 54.1%
Idefics2-8B:
```
hf-multimodal (pretrained=HuggingFaceM4/idefics2-8b,attn_implementation=flash_attention_2,dtype=bfloat16,convert_img_format=True,max_images=2), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 2
```
```
| Groups |Version|Filter|n-shot|Metric| |Value | |Stderr|
|--------------------------------|------:|------|------|------|---|-----:|---|-----:|
|mmmu_val | 0|none | |acc |↑ |0.4011|± |0.0154|
| - Art and Design | 0|none | |acc |↑ |0.6167|± |0.0436|
| - Business | 0|none | |acc |↑ |0.3200|± |0.0373|
| - Health and Medicine | 0|none | |acc |↑ |0.4000|± |0.0401|
| - Humanities and Social Science| 0|none | |acc |↑ |0.5750|± |0.0424|
| - Science | 0|none | |acc |↑ |0.2600|± |0.0358|
| - Tech and Engineering | 0|none | |acc |↑ |0.3381|± |0.0312|
```
Author-reported score: ~43%
Llava-v1.6-Mistral-7B:
```
hf-multimodal (pretrained=llava-hf/llava-v1.6-mistral-7b-hf,attn_implementation=flash_attention_2,dtype=bfloat16,convert_img_format=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 2
```
```
| Groups |Version|Filter|n-shot|Metric| |Value | |Stderr|
|--------------------------------|------:|------|------|------|---|-----:|---|-----:|
|mmmu_val | 0|none | |acc |↑ |0.3522|± |0.0151|
| - Art and Design | 0|none | |acc |↑ |0.5167|± |0.0440|
| - Business | 0|none | |acc |↑ |0.2667|± |0.0362|
| - Health and Medicine | 0|none | |acc |↑ |0.3867|± |0.0397|
| - Humanities and Social Science| 0|none | |acc |↑ |0.5917|± |0.0433|
| - Science | 0|none | |acc |↑ |0.2200|± |0.0342|
| - Tech and Engineering | 0|none | |acc |↑ |0.2524|± |0.0299|
```
Author-reported score: 35.3%
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
group: mmmu_val_art_and_design
group_alias: Art and Design
task:
- mmmu_val_art
- mmmu_val_art_theory
- mmmu_val_design
- mmmu_val_music
aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
metadata:
version: 0.0
group: mmmu_val_business
group_alias: Business
task:
- mmmu_val_accounting
- mmmu_val_economics
- mmmu_val_finance
- mmmu_val_manage
- mmmu_val_marketing
aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
metadata:
version: 0.0
group: mmmu_val_health_and_medicine
group_alias: Health and Medicine
task:
- mmmu_val_basic_medical_science
- mmmu_val_clinical_medicine
- mmmu_val_diagnostics_and_laboratory_medicine
- mmmu_val_pharmacy
- mmmu_val_public_health
aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
metadata:
version: 0.0
group: mmmu_val_humanities_and_social_science
group_alias: Humanities and Social Science
task:
- mmmu_val_history
- mmmu_val_literature
- mmmu_val_sociology
- mmmu_val_psychology
aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
metadata:
version: 0.0
group: mmmu_val
task:
- mmmu_val_art_and_design
- mmmu_val_business
- mmmu_val_health_and_medicine
- mmmu_val_humanities_and_social_science
- mmmu_val_science
- mmmu_val_tech_and_engineering
aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
metadata:
version: 0.0
group: mmmu_val_science
group_alias: Science
task:
- mmmu_val_biology
- mmmu_val_chemistry
- mmmu_val_geography
- mmmu_val_math
- mmmu_val_physics
aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
metadata:
version: 0.0
group: mmmu_val_tech_and_engineering
group_alias: Tech and Engineering
task:
- mmmu_val_agriculture
- mmmu_val_architecture_and_engineering
- mmmu_val_computer_science
- mmmu_val_electronics
- mmmu_val_energy_and_power
- mmmu_val_materials
- mmmu_val_mechanical_engineering
aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
metadata:
version: 0.0
dataset_path: MMMU/MMMU
validation_split: validation
output_type: generate_until
doc_to_image: !function utils.doc_to_image
doc_to_text: !function utils.doc_to_text
doc_to_target: "answer"
process_results: !function utils.process_results
generation_kwargs:
until:
- "<|endoftext|>"
temperature: 0.0
do_sample: false
max_gen_toks: 512
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
metadata:
version: 0.0
task: mmmu_val_accounting
include: _template_yaml
task_alias: Accounting
dataset_name: Accounting
task: mmmu_val_agriculture
include: _template_yaml
task_alias: Agriculture
dataset_name: Agriculture
task: mmmu_val_architecture_and_engineering
include: _template_yaml
task_alias: Architecture and Engineering
dataset_name: Architecture_and_Engineering