Commit 1dda496f authored by Ashvin Nihalani's avatar Ashvin Nihalani
Browse files

Adding LLaVa support

Updating APIs for MM support

Adding MLLM dependencies

Rebase off mainline
parent 568af943
......@@ -6,6 +6,10 @@ OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
InputType = Literal[
"text", "text_image"
]
@dataclass
class Instance:
......@@ -13,6 +17,10 @@ class Instance:
doc: dict
arguments: tuple
idx: int
# Input type for multimodal
input_type: InputType = "text"
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
)
......@@ -24,6 +32,7 @@ class Instance:
doc_id: Optional[int] = None
repeats: Optional[int] = None
def __post_init__(self) -> None:
# unpack metadata field
self.task_name, self.doc_id, self.repeats = self.metadata
......
......@@ -36,7 +36,7 @@ class LM(abc.ABC):
LM calls whenever possible.
:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
A list of Instance objects, with property `args` which returns a tuple (context, continuation, visual_list).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
......@@ -44,6 +44,8 @@ class LM(abc.ABC):
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
'visual_list: Optional[list[dict]]'
Visual Input to the model. Can be None
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
......
......@@ -26,7 +26,7 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.instance import Instance, OutputType, InputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
......@@ -48,6 +48,11 @@ ALL_OUTPUT_TYPES = [
"generate_until",
]
ALL_INPUT_TYPES = [
"text",
"text_image",
]
eval_logger = logging.getLogger("lm-eval")
......@@ -75,6 +80,7 @@ class TaskConfig(dict):
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_visual: Union[Callable, str] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
......@@ -87,6 +93,7 @@ class TaskConfig(dict):
# scoring options
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
input_type: InputType = "text"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[Union[str, list]] = None
......@@ -365,6 +372,10 @@ class Task(abc.ABC):
def doc_to_target(self, doc):
pass
@abc.abstractmethod
def doc_to_visual(self, doc):
pass
def build_all_requests(
self,
*,
......@@ -722,6 +733,13 @@ class ConfigurableTask(Task):
)
self.OUTPUT_TYPE = self.config.output_type
if self.config.input_type is not None:
if self.config.input_type not in ALL_INPUT_TYPES:
raise ValueError(
f"Got invalid output_type '{self.config.input_type}', must be in '{','.join(ALL_INPUT_TYPES)}'"
)
self.INPUT_TYPE = self.config.input_type
if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path
......@@ -1260,6 +1278,15 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_visual(self, doc:dict) -> Union[int, str, list]:
if type(self.config.doc_to_visual) is str:
assert self.config.doc_to_visual in self.features
# Single Image. Still return a list for consistency
return doc[self.config.doc_to_visual]
else:
assert callable(self.config.doc_to_visual)
return self.config.doc_to_visual(doc)
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
......@@ -1313,10 +1340,13 @@ class ConfigurableTask(Task):
return request_list
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
if self.INPUT_TYPE == "text_image":
arguments = (ctx, deepcopy(self.config.generation_kwargs), self.doc_to_visual, doc, self.config.task)
elif self.INPUT_TYPE == "text":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
request_type=self.OUTPUT_TYPE, input_type=self.INPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
)
def process_results(self, doc, results):
......@@ -1524,6 +1554,7 @@ class ConfigurableTask(Task):
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"group_name={getattr(self.config, 'group', None)},"
f"output_type={self.OUTPUT_TYPE},"
f"input_type={self.INPUT_TYPE}",
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})"
)
......
......@@ -5,6 +5,7 @@ from . import (
huggingface,
mamba_lm,
nemo_lm,
llava,
neuralmagic,
neuron_optimum,
openai_completions,
......
import copy
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
import logging
import torch
from typing import List, Optional, Union, Tuple
import warnings
from lm_eval.models.utils import Collator
warnings.filterwarnings("ignore")
eval_logger = logging.getLogger("lm-eval")
try:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \
IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
except ImportError:
eval_logger.error("LLaVA is not installed. Please install LLaVA to use this model.")
@register_model("llava")
class Llava(LM):
"""
Llava Model
"""
def __init__(
self,
pretrained: str = "liuhaotian/llava-v1.5-7b",
truncation: Optional[bool] = True,
device: Optional[str] = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto",
batch_size: Optional[Union[int, str]] = 1,
trust_remote_code: Optional[bool] = False,
revision=None,
use_flash_attention_2=False,
conv_template="vicuna_v1",
use_cache=True,
truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
accelerator = Accelerator()
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
else:
self._device = device
(
self._tokenizer,
self._model,
self._image_processor,
self._max_length,
) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device)
self._config = self._model.config
self.model.eval()
self.model.tie_weights()
self.truncation = truncation
self.batch_size_per_gpu = int(batch_size)
self.conv_template = conv_template
self.use_cache = use_cache
self.truncate_context = truncate_context
# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU,
DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
# Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
# I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info(
"Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.model.to(self._device)
self._rank = 0
self._word_size = 1
@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config
@property
def tokenizer(self):
return self._tokenizer
@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
return self._max_length
def pad_sequence(self, input_ids, batch_first, padding_value):
if self.tokenizer.padding_side == "left":
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
if self.tokenizer.padding_side == "left":
input_ids = torch.flip(input_ids, [1])
return input_ids
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
""" """
add_special_tokens = False if add_special_tokens is None else add_special_tokens
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
for contexts, doc_to_target, doc_to_visual, doc, task in [reg.args for reg in requests]:
# encode, pad, and truncate contexts for this batch
if type(doc_to_target) == str:
continuation = doc_to_target
else:
continuation = doc_to_target(doc)
visuals = [doc_to_visual(doc)]
visuals = self.flatten(visuals)
if visuals:
image = process_images(visuals, self._image_processor, self._config)
if type(image) is list:
image = [_image.to(dtype=torch.float16, device=self.device) for _image in image]
else:
image = image.to(dtype=torch.float16, device=self.device)
else:
image = None
prompts_input = contexts[0]
if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
"""
Three senarios:
1. No image, and there for, no image token should be added.
2. image token is already specified in the context, so we don't need to add it.
3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
"""
image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
image_tokens = " ".join(image_tokens)
prompts_input = image_tokens + "\n" + contexts[0]
conv = conv_templates[self.conv_template].copy()
conv.append_message(conv.roles[0], prompts_input)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(
0).to(self.device)
# Add the answer of the second role
conv.messages[1][1] = continuation
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(
0).to(self.device)
labels = input_ids.clone()
# Context part no need to calculate for loss
labels[0, : contxt_id.shape[1]] = -100
with torch.inference_mode():
outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True)
loss = outputs["loss"]
# loss = torch.exp(loss)
logits = outputs["logits"]
greedy_tokens = logits.argmax(dim=-1)
cont_toks = input_ids[:, contxt_id.shape[1]:] # [1, seq]
greedy_tokens = greedy_tokens[:, contxt_id.shape[1]: input_ids.shape[1]] # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
res.append((float(loss.item()), bool(max_equal)))
pbar.update(1)
pbar.close()
return res
def loglikelihood_rolling(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError()
def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list
def generate_until(self, requests: List[Instance]) -> List[str]:
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return -len(toks), x[0]
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = Collator(
[reg.arguments for reg in requests],
sort_fn=_collate,
group_by="gen_kwargs",
group_fn=lambda x: x[1],
)
chunks = re_ords.get_batched(n=self.batch_size)
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(
requests) // self.batch_size + 1
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc, task = zip(*chunk)
task = task[0]
visuals = [doc_to_visual[0](doc[0])]
visuals = self.flatten(visuals)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# Set default values for until and max_new_tokens
until = [self.tok_decode(self.eot_token_id)]
# Update values from gen_kwargs if present
if "until" in gen_kwargs:
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
if "image_aspect_ratio" in gen_kwargs.keys() and "image_aspect_ratio" not in self._config.__dict__:
# here we should pop it out of gen_kwargs so that it doesn't get passed to the model for next step of generation
self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio")
eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")
# encode, pad, and truncate contexts for this batch
if visuals:
image_tensor = process_images(visuals, self._image_processor, self._config)
if type(image_tensor) is list:
image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
else:
image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
else:
image_tensor = None
# prompts_input = contexts[0]
question_input = []
for visual, context in zip(visuals, contexts):
if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
"""
Three senarios:
1. No image, and there for, no image token should be added.
2. image token is already specified in the context, so we don't need to add it.
3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
"""
image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [
DEFAULT_IMAGE_TOKEN]
image_tokens = " ".join(image_tokens)
question = image_tokens + "\n" + context
else:
question = context
conv = conv_templates[self.conv_template].copy()
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
question_input.append(prompt_question)
# The above for loop has bugs. When there is no visuals, e.g. pure text,
# there will be no for loop execute resulting in an empty question_input (because no visuals)
# Scenario 1 won't even be execute
if len(visuals) == 0:
for context in contexts:
question = context
conv = conv_templates[self.conv_template].copy()
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
question_input.append(prompt_question)
# input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
# preconfigure gen_kwargs with defaults
gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
if "max_gen_toks" not in gen_kwargs:
gen_kwargs["max_gen_toks"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for
prompt in question_input]
pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device)
attention_masks = input_ids.ne(pad_token_ids).to(self.device)
# These steps are not in LLaVA's original code, but are necessary for generation to work
# TODO: pay attention to this major generation step...
try:
cont = self.model.generate(
input_ids,
attention_mask=attention_masks,
pad_token_id=pad_token_ids,
images=image_tensor,
do_sample=gen_kwargs["do_sample"],
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
max_new_tokens=gen_kwargs["max_gen_toks"],
use_cache=self.use_cache,
)
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
except Exception as e:
eval_logger.error(f"Error {e} in generating")
cont = ""
text_outputs = [""]
# cont_toks_list = cont.tolist()
# for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LMM
# if self.truncate_context:
# cont_toks = cont_toks[input_ids.shape[1] :]
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
# if self.truncate_context:
# for term in until:
# if len(term) > 0:
# # ignore '' separator,
# # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
# text_outputs = text_outputs.split(term)[0]
res.extend(text_outputs)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)
pbar.close()
return res
\ No newline at end of file
dataset_path: lmms-lab/MMMU
task: "mmmu_val"
validation_split: validation
output_type: generate_until
input_type: text_image
doc_to_visual: !function utils.mmmu_doc_to_visual
doc_to_text: !function utils.mmmu_doc_to_text
doc_to_target: "answer"
# The return value of process_results will be used by metrics
process_results: !function utils.mmmu_process_results
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
generation_kwargs:
until:
- "<|endoftext|>"
temperature: 1.0
do_sample: false
max_gen_toks: 512
top_p: 0.94
repetition_penalty: 1.0
image_aspect_ratio: original
metric_list:
- metric: mmmu_acc
aggregation: !function utils.mmmu_aggregate_results
higher_is_better: true
\ No newline at end of file
from collections import defaultdict
import re
import ast
import random
import numpy as np
import logging
lmms_logger = logging.getLogger("lm-eval")
MULTI_CHOICE_PROMPT = "Answer with the option letter from the given choices directly."
OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."
def replace_images_tokens(input_string):
for i in range(1, 8):
question_text = f"<image {i}>"
query_text = "<image>"
if question_text in input_string:
input_string = input_string.replace(question_text, query_text)
return input_string
def parse_options(options):
option_letters = [chr(ord("A") + i) for i in range(len(options))]
choices_str = "\n".join([f"{option_letter}. {option}" for option_letter, option in zip(option_letters, options)])
return choices_str
def construct_prompt(doc):
question = doc["question"]
if doc["question_type"] == "multiple-choice":
# Weirdly, data["options"] is a string in MMMU Huggingface dataset
parsed_options = parse_options(ast.literal_eval(doc["options"]))
# parsed_options already prepends a newline so no need to add space here
question = f"{question}\n{parsed_options}\n{MULTI_CHOICE_PROMPT}"
else:
question = f"{question}\n{OPEN_ENDED_PROMPT}"
return question
def mmmu_doc_to_text(doc):
question = construct_prompt(doc)
return replace_images_tokens(question)
def mmmu_doc_to_visual(doc):
prompt = construct_prompt(doc)
image_tokens = re.findall(r"<image \d+>", prompt)
# Remove <> and swap space as _
image_tokens = [image_token.strip("<>").replace(" ", "_") for image_token in image_tokens]
visual = [doc[image_token].convert("RGB") for image_token in image_tokens]
return visual
def mmmu_process_results(doc, results):
pred = results[0]
if doc["question_type"] == "multiple-choice":
index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
else:
parsed_pred = parse_open_response(pred)
id = doc["id"]
mmmu_acc = {"id": id, "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_pred}
return {
"mmmu_acc": mmmu_acc
}
def extract_subset_name(input_string):
# Define a regex pattern to match "validation_" at the beginning and "_<number>" at the end
split = input_string.split("_")[0]
pattern = re.compile(rf"^{split}_(.+?)_\d+$")
match = pattern.search(input_string)
if match:
return match.group(1)
else:
raise ValueError(f'No match found in "{input_string}"')
def mmmu_aggregate_results(results):
evaluation_result = {}
subset_to_eval_samples = defaultdict(list)
for result in results:
subset_to_eval_samples[result["subdomain"]].append(result)
for subset, sub_eval_samples in subset_to_eval_samples.items():
judge_dict, metric_dict = evaluate_mmmu(sub_eval_samples)
metric_dict.update({"num_example": len(sub_eval_samples)})
evaluation_result[subset] = metric_dict
printable_results = {}
for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
in_domain_cat_results = {}
for cat_name in in_domain_cats:
if cat_name in evaluation_result.keys():
in_domain_cat_results[cat_name] = evaluation_result[cat_name]
else:
pass
in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
in_domain_data_num = sum([cat_results["num_example"] for cat_results in in_domain_cat_results.values()])
printable_results["Overall-" + domain] = {
"num": int(in_domain_data_num),
"mmmu_acc": round(in_domain_ins_acc, 3),
}
# add sub category
for cat_name, cat_results in in_domain_cat_results.items():
printable_results[cat_name] = {
"num": int(cat_results["num_example"]),
"mmmu_acc": round(cat_results["mmmu_acc"], 3),
}
all_ins_acc = calculate_ins_level_acc(evaluation_result)
printable_results["Overall"] = {
"num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
"mmmu_acc": round(all_ins_acc, 3),
}
print(printable_results)
return printable_results["Overall"]
##################
# Helper functions written by official MMMU repo.
##################
def calculate_ins_level_acc(results):
"""Calculate the instruction level accuracy for given Subject results
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L246
"""
acc = 0
ins_num = 0
for cat_results in results.values():
acc += cat_results["mmmu_acc"] * cat_results["num_example"]
ins_num += cat_results["num_example"]
if ins_num == 0:
return 0
return acc / ins_num
DOMAIN_CAT2SUB_CAT = {
"Art and Design": ["Art", "Art_Theory", "Design", "Music"],
"Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
"Science": [
"Biology",
"Chemistry",
"Geography",
"Math",
"Physics",
],
"Health and Medicine": [
"Basic_Medical_Science",
"Clinical_Medicine",
"Diagnostics_and_Laboratory_Medicine",
"Pharmacy",
"Public_Health",
],
"Humanities and Social Science": [
"History",
"Literature",
"Sociology",
"Psychology",
],
"Tech and Engineering": [
"Agriculture",
"Architecture_and_Engineering",
"Computer_Science",
"Electronics",
"Energy_and_Power",
"Materials",
"Mechanical_Engineering",
],
}
def eval_multi_choice(gold_i, pred_i):
"""
Evaluate a multiple choice instance.
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L175
"""
correct = False
# only they are exactly the same, we consider it as correct
if isinstance(gold_i, list):
for answer in gold_i:
if answer == pred_i:
correct = True
break
else: # gold_i is a string
if gold_i == pred_i:
correct = True
return correct
def eval_open(gold_i, pred_i):
"""
Evaluate an open question instance
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L191
"""
correct = False
if isinstance(gold_i, list):
# use float to avoid trivial matches
norm_answers = []
for answer in gold_i:
norm_answers.extend(normalize_str(answer))
else:
norm_answers = normalize_str(gold_i)
for pred in pred_i: # pred is already normalized in parse response phase
if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
for norm_ans in norm_answers:
# only see if the string answer in the string pred
if isinstance(norm_ans, str) and norm_ans in pred:
if not correct:
correct = True
break
else: # it's a float number
if pred in norm_answers:
if not correct:
correct = True
break
return correct
def evaluate_mmmu(samples):
"""
Batch evaluation for multiple choice and open questions.
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L219
"""
pred_correct = 0
judge_dict = dict()
for sample in samples:
gold_i = sample["answer"]
pred_i = sample["parsed_pred"]
if sample["question_type"] == "multiple-choice":
correct = eval_multi_choice(gold_i, pred_i)
else: # open question
correct = eval_open(gold_i, pred_i)
if correct:
judge_dict[sample["id"]] = "Correct"
pred_correct += 1
else:
judge_dict[sample["id"]] = "Wrong"
if len(samples) == 0:
return {"mmmu_acc": 0}
return judge_dict, {"mmmu_acc": pred_correct / len(samples)}
def parse_multi_choice_response(response, all_choices, index2ans):
"""
Parse the prediction from the generated response.
Return the predicted index e.g., A, B, C, D.
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10
"""
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " " # add space to avoid partial match
index_ans = True
ans_with_brack = False
candidates = []
for choice in all_choices: # e.g., (A) (B) (C) (D)
if f"({choice})" in response:
candidates.append(choice)
ans_with_brack = True
if len(candidates) == 0:
for choice in all_choices: # e.g., A B C D
if f"{choice} " in response:
candidates.append(choice)
if len(candidates) == 0:
for choice in all_choices: # e.g., A. B. C. D.
if f"{choice}." in response:
candidates.append(choice)
# if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False # it's content ans.
if len(candidates) == 0: # still not get answer, randomly choose one.
pred_index = random.choice(all_choices)
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_brack:
for can in candidates:
index = response.rfind(f"({can})")
start_indexes.append(index) # -1 will be ignored anyway
# start_indexes = [generated_response.index(f'({can})') for can in candidates]
else:
for can in candidates:
index = response.rfind(f" {can} ")
start_indexes.append(index)
else:
for can in candidates:
index = response.lower().rfind(index2ans[can].lower())
start_indexes.append(index)
# get the last one
pred_index = candidates[np.argmax(start_indexes)]
else: # if only one candidate, use it.
pred_index = candidates[0]
return pred_index
def extract_numbers(string):
"""
Exact all forms of numbers from a string with regex.
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L100
"""
# Pattern for numbers with commas
pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
# Pattern for scientific notation
pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
# Pattern for simple numbers without commas
pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
# Extract numbers with commas
numbers_with_commas = re.findall(pattern_commas, string)
# Extract numbers in scientific notation
numbers_scientific = re.findall(pattern_scientific, string)
# Extract simple numbers without commas
numbers_simple = re.findall(pattern_simple, string)
# Combine all extracted numbersz
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
return all_numbers
def check_is_number(string):
"""
Check if the given string a number.
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L65
"""
try:
float(string.replace(",", ""))
return True
except ValueError:
# check if there's comma inside
return False
def normalize_str(string):
"""
Normalize the str to lower case and make them float numbers if possible.
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L76
"""
# check if characters in the string
# if number, numerize it.
string = string.strip()
is_number = check_is_number(string)
if is_number:
string = string.replace(",", "")
string = float(string)
# leave 2 decimal
string = round(string, 2)
return [string]
else: # it's likely to be a string
# lower it
string = string.lower()
if len(string) == 1:
return [" " + string, string + " "] # avoid trivial matches
return [string]
def parse_open_response(response):
    """
    Parse the prediction from the generated response.
    Return a list of predicted strings or numbers.

    The result is the union of extracted key sub-responses and any numbers
    found inside them (via `extract_numbers`), all passed through
    `normalize_str`, with duplicates removed. Ordering of the returned list
    is not deterministic because of the final `set` round-trip.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L122
    """
    # content = content.strip("\n").strip(".").strip(" ")
    def get_key_subresponses(response):
        # Extract the sentence fragments most likely to contain the final
        # answer, using indicator phrases like "answer " / "therefore ".
        key_responses = []
        response = response.strip().strip(".").lower()
        # NOTE(review): the response is lower-cased above, so the `(?=[A-Z])`
        # lookahead can never match and splitting effectively happens on
        # newlines only. Kept as-is to mirror the upstream MMMU code.
        sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
        indicators_of_keys = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        # NOTE(review): this reassignment makes the first `key_responses = []`
        # above dead code; harmless, kept to match upstream.
        key_responses = []
        for index, resp in enumerate(sub_responses):
            # if last one, accept it's an equation (the entire response can be just one sentence with equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(["="])
            shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
            for indicator in indicators_of_keys:
                if indicator in resp:
                    # keep the shortest tail following any indicator phrase
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    else:
                        if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
                            shortest_key_response = resp.split(indicator)[-1].strip()
                    # key_responses.append(resp.split(indicator)[1].strip())
            if shortest_key_response:
                # and it's not trivial (bare punctuation); ":" appears twice
                # in this list — redundant but harmless, kept to match upstream
                if shortest_key_response.strip() not in [
                    ":",
                    ",",
                    ".",
                    "!",
                    "?",
                    ";",
                    ":",
                    "'",
                ]:
                    key_responses.append(shortest_key_response)
        if len(key_responses) == 0: # did not find any
            return [response]
        return key_responses
    # pdb.set_trace()
    key_responses = get_key_subresponses(response)
    pred_list = key_responses.copy() # keep the original string response
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))
    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list
    # remove duplicates
    pred_list = list(set(pred_list))
    return pred_list
def get_multi_choice_info(options):
    """
    Assign letter labels to the options of a multiple-choice question.

    Returns a pair (index2ans, all_choices): a dict mapping each letter
    ("A", "B", ...) to its option text, and the list of letters in order.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/data_utils.py#L54
    """
    index2ans = {}
    all_choices = []
    for offset, option in enumerate(options):
        letter = chr(ord("A") + offset)
        index2ans[letter] = option
        all_choices.append(letter)
    return index2ans, all_choices
\ No newline at end of file
......@@ -10,7 +10,7 @@ import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from typing import Any, Callable, List
from typing import Any, Callable, List, Union, Tuple, Iterable, Optional, Iterator
import numpy as np
import yaml
......@@ -364,7 +364,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
se = dic[m + "_stderr" + "," + f]
if se != "N/A":
se = "%.4f" % se
values.append([k, version, f, n, m, hib, "%.4f" % v, "±", se])
if type(v) is dict:
for v_key, v_v in v.items():
values.append([k, version, f, n, m + "_" + v_key, "%.4f" % v_v, "±", se])
else:
values.append([k, version, f, n, m, hib, "%.4f" % v, "±", se])
else:
values.append([k, version, f, n, m, hib, "%.4f" % v, "", ""])
k = ""
......@@ -485,3 +489,160 @@ def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
among ranks in multigpu setting or only pulling a sample of documents
"""
return islice(raw_iterator, rank, limit, world_size)
class Collator:
    """
    A class for reordering and batching elements of an array.
    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
    Typical usage: sort/group requests for efficient batched processing,
    consume batches from `get_batched`, then call `get_original` to restore
    the results to the order of the input array.
    """
    def __init__(
        self,
        arr: List,
        sort_fn: Callable,
        group_fn: Callable = lambda x: x[1],
        grouping: bool = False,
    ) -> None:
        # Whether to group elements (by `group_fn`) before sorting/batching.
        self.grouping = grouping
        # Sort key, applied to the raw element (not the enumerated pair).
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1]) # first index are enumerated indices
        # Original positions, appended in the order `_reorder` yields
        # elements; consumed later by `get_original`.
        self.reorder_indices: List = []
        self.size = len(arr)
        # Pair each element with its original index so order can be restored.
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr)) # [indices, (arr)]
        if self.grouping is True:
            self.group_by_index()
    def group_by_index(self) -> None:
        """Replace the flat (index, element) tuple with a dict of groups keyed by `group_fn`."""
        self.arr_with_indices = self.group(self.arr_with_indices, fn=self.group_fn, values=False)
    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array.
        When grouping is enabled, each group is sorted and batched
        independently (batches never span groups).
        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.
        Yields:
        Iterator: An iterator over batches of reordered elements.
        """
        if self.grouping:
            for (
                key,
                values,
            ) in self.arr_with_indices.items(): # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices) # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Reorders the elements in the array based on the sorting function.
        Side effect: appends the original indices (in sorted order) to
        `self.reorder_indices` so `get_original` can undo the reordering.
        Note this is a generator — indices are recorded when `sorted` runs,
        i.e. as soon as iteration starts.
        Parameters:
        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.
        Yields:
        List: Yields reordered elements one by one.
        """
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]
    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.
        Asserts that every original position was covered, i.e. `newarr`
        contains one result for each element yielded so far.
        Parameters:
        - newarr (List): The reordered array.
        Returns:
        List: The array with elements restored to their original order.
        """
        res = [None] * self.size
        cov = [False] * self.size
        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True
        assert all(cov)
        return res
    def __len__(self):
        """Number of elements in the original array."""
        return self.size
    @staticmethod
    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """
        Groups elements of an iterable based on a provided function.
        When `fn(ob)` returns a dict, a hashable key is built from its sorted
        items (iterable values collapsed to tuples); on TypeError the raw
        `fn(ob)` is used as the key directly.
        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - values (bool): If True, returns the values of the group. Defaults to False.
        Returns:
        Iterable: An iterable of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                hashable_dict = tuple(
                    (
                        key,
                        tuple(value) if isinstance(value, collections.abc.Iterable) else value,
                    )
                    for key, value in sorted(fn(ob).items())
                )
                res[hashable_dict].append(ob)
            except TypeError:
                # unhashable/uncomparable key material — fall back to the raw key
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()
    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching
        Parameters:
        - iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
        Returns:
        An iterator that yields chunks of the input iterable.
        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        # materialize once so `fn` can inspect the whole iterable
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []
        # trailing partial chunk
        if arr:
            yield arr
\ No newline at end of file
......@@ -66,6 +66,7 @@ ifeval = ["langdetect", "immutabledict"]
neuronx = ["optimum[neuronx]"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
mllm = ["transformers >= 4.40.0", "llava-torch==1.1.1"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
openai = ["openai==1.3.9", "tiktoken"]
optimum = ["optimum[openvino]"]
......@@ -86,6 +87,7 @@ all = [
"lm_eval[ifeval]",
"lm_eval[mamba]",
"lm_eval[math]",
"lm_eval[mllm]",
"lm_eval[multilingual]",
"lm_eval[openai]",
"lm_eval[promptsource]",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment