Commit 7d7a3a1c authored by Ashvin Nihalani

More Format Changes

parent 2dc436fa
@@ -6,9 +6,7 @@ OutputType = Literal[
    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
-InputType = Literal[
-    "text", "text_image"
-]
+InputType = Literal["text", "text_image"]


@dataclass
@@ -32,7 +30,6 @@ class Instance:
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        # unpack metadata field
        self.task_name, self.doc_id, self.repeats = self.metadata
...
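Note on the hunks above: with the narrowed `InputType` literal and the `__post_init__` unpacking, a request carries its routing information as a single `(task_name, doc_id, repeats)` metadata tuple. A rough sketch of constructing one Instance (field values are hypothetical placeholders; the exact constructor signature may differ):

    inst = Instance(                                   # Instance as defined in the hunk above
        request_type="generate_until",                 # one of the OutputType literals
        input_type="text_image",                       # must be "text" or "text_image"
        doc={"question": "What is in the image?"},
        arguments=("<prompt>", {"max_gen_toks": 32}),
        idx=0,
        metadata=("mmmu_val", 0, 1),                   # unpacked into task_name, doc_id, repeats
    )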
@@ -1278,7 +1278,7 @@ class ConfigurableTask(Task):
        else:
            raise TypeError

-    def doc_to_visual(self, doc:dict) -> Union[int, str, list]:
+    def doc_to_visual(self, doc: dict) -> Union[int, str, list]:
        if isinstance(self.config.doc_to_visual, str):
            assert self.config.doc_to_visual in self.features
            # Single Image. Still return a list for consistency
@@ -1341,12 +1341,23 @@ class ConfigurableTask(Task):
        elif self.OUTPUT_TYPE == "generate_until":
            if self.INPUT_TYPE == "text_image":
-                arguments = (ctx, deepcopy(self.config.generation_kwargs), self.doc_to_visual, doc, self.config.task)
+                arguments = (
+                    ctx,
+                    deepcopy(self.config.generation_kwargs),
+                    self.doc_to_visual,
+                    doc,
+                    self.config.task,
+                )
            elif self.INPUT_TYPE == "text":
                arguments = (ctx, deepcopy(self.config.generation_kwargs))

        return Instance(
-            request_type=self.OUTPUT_TYPE, input_type=self.INPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
+            request_type=self.OUTPUT_TYPE,
+            input_type=self.INPUT_TYPE,
+            doc=doc,
+            arguments=arguments,
+            idx=0,
+            **kwargs,
        )

    def process_results(self, doc, results):
@@ -1556,7 +1567,7 @@ class ConfigurableTask(Task):
            f"output_type={self.OUTPUT_TYPE},"
            f"input_type={self.INPUT_TYPE}",
            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
-            f"num_samples={len(self.eval_docs)})"
+            f"num_samples={len(self.eval_docs)})",
        )
...
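The reformatted `generate_until` branch also makes the request payload easier to read: a text-only request carries `(ctx, generation_kwargs)`, while a text+image request additionally carries the visual-extraction callable, the doc, and the task name; that five-element tuple is what `Llava.generate_until` later unpacks per chunk. Schematically (all values are stand-ins):

    doc = {"question": "What is shown?"}
    doc_to_visual = lambda d: []  # stand-in for the task's visual extractor
    text_args = ("<prompt>", {"max_gen_toks": 1024})
    image_args = ("<prompt>", {"max_gen_toks": 1024}, doc_to_visual, doc, "mmmu_val")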
@@ -607,16 +607,16 @@ def evaluate(
                ]
                # compute group's pooled metric and stderr
-                results[group][
-                    metric
-                ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
+                results[group][metric] = (
+                    lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
+                )
                # TODO: calculate grouped metric using aggregation fn
                if "N/A" in stderrs:
                    results[group][stderr] = "N/A"
                else:
-                    results[group][
-                        stderr
-                    ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
+                    results[group][stderr] = (
+                        lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
+                    )
                    # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                    # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                    # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
...
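For the group aggregation above, the pooled metric is a sample-size-weighted combination of subtask scores. Assuming `aggregate_subtask_metrics` behaves like a weighted mean (a sketch for intuition, not the library's exact implementation):

    def weighted_mean(metrics, sizes):
        # pool per-subtask scores, weighting each by its number of samples
        return sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)

    weighted_mean([0.70, 0.50], [200, 100])  # -> 0.6333...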
@@ -275,9 +275,9 @@ def consolidate_results(
                    metric_key
                ]
                results[task_output.task_name]["samples"] = task_output.sample_len
-                results[task_output.task_name][
-                    f"{metric}_stderr,{filter_key}"
-                ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
+                results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
+                    task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
+                )
    return results, samples, configs, versions, num_fewshot, higher_is_better
...
@@ -40,19 +40,19 @@ class Llava(LM):
    """

    def __init__(
        self,
        pretrained: str = "liuhaotian/llava-v1.5-7b",
        truncation: Optional[bool] = True,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = False,
        revision=None,
        use_flash_attention_2=False,
        conv_template="vicuna_v1",
        use_cache=True,
        truncate_context=False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now
@@ -68,7 +68,12 @@ class Llava(LM):
            self._model,
            self._image_processor,
            self._max_length,
-        ) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device)
+        ) = load_pretrained_model(
+            pretrained,
+            None,
+            get_model_name_from_path(pretrained),
+            device_map=self._device,
+        )
        self._config = self._model.config
        self.model.eval()
        self.model.tie_weights()
@@ -79,26 +84,40 @@ class Llava(LM):
        self.truncate_context = truncate_context
        # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
        if accelerator.num_processes > 1:
-            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU,
-                                                    DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+            assert accelerator.distributed_type in [
+                DistributedType.FSDP,
+                DistributedType.MULTI_GPU,
+                DistributedType.DEEPSPEED,
+            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
-                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
+                    "train_batch_size": self.batch_size_per_gpu
+                    * accelerator.num_processes,
                }
-                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
+                AcceleratorState().deepspeed_plugin.deepspeed_config_process(
+                    must_match=True, **kwargs
+                )
                eval_logger.info(
-                    "Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
-            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    "Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0"
+                )
+            if (
+                accelerator.distributed_type == DistributedType.FSDP
+                or accelerator.distributed_type == DistributedType.DEEPSPEED
+            ):
                self._model = accelerator.prepare(self.model)
            else:
-                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+                self._model = accelerator.prepare_model(
+                    self.model, evaluation_mode=True
+                )
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
-                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
+                eval_logger.info(
+                    f"Using {accelerator.num_processes} devices with data parallelism"
+                )
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
@@ -135,7 +154,9 @@ class Llava(LM):
    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
-        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids, batch_first=batch_first, padding_value=padding_value
+        )
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids
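The `pad_sequence` helper above works around `torch.nn.utils.rnn.pad_sequence` only padding on the right: for left padding it flips each sequence, pads, then flips the padded batch back. A standalone illustration of the trick:

    import torch

    seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    flipped = [torch.flip(s, [0]) for s in seqs]
    padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=0)
    left_padded = torch.flip(padded, [1])
    # tensor([[1, 2, 3],
    #         [0, 4, 5]])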
@@ -156,7 +177,9 @@ class Llava(LM):
    def world_size(self):
        return self._world_size

-    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
+    def tok_encode(
+        self, string: str, left_truncate_len=None, add_special_tokens=None
+    ) -> List[int]:
        """ """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
@@ -171,9 +194,13 @@ class Llava(LM):
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        # TODO
        res = []
-        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
+        pbar = tqdm(
+            total=len(requests), disable=(self.rank != 0), desc="Model Responding"
+        )

-        for contexts, doc_to_target, doc_to_visual, doc, task in [reg.args for reg in requests]:
+        for contexts, doc_to_target, doc_to_visual, doc, task in [
+            reg.args for reg in requests
+        ]:
            # encode, pad, and truncate contexts for this batch
            if isinstance(doc_to_target, str):
                continuation = doc_to_target
@@ -184,7 +211,10 @@ class Llava(LM):
            if visuals:
                image = process_images(visuals, self._image_processor, self._config)
                if isinstance(image, list):
-                    image = [_image.to(dtype=torch.float16, device=self.device) for _image in image]
+                    image = [
+                        _image.to(dtype=torch.float16, device=self.device)
+                        for _image in image
+                    ]
                else:
                    image = image.to(dtype=torch.float16, device=self.device)
            else:
@@ -192,7 +222,11 @@ class Llava(LM):

            prompts_input = contexts[0]

-            if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
+            if (
+                image is not None
+                and len(image) != 0
+                and DEFAULT_IMAGE_TOKEN not in prompts_input
+            ):
                """
                Three senarios:
                1. No image, and there for, no image token should be added.
@@ -207,32 +241,48 @@ class Llava(LM):
            conv.append_message(conv.roles[0], prompts_input)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
-            contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(
-                0).to(self.device)
+            contxt_id = (
+                tokenizer_image_token(
+                    prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+                )
+                .unsqueeze(0)
+                .to(self.device)
+            )
            # Add the answer of the second role
            conv.messages[1][1] = continuation
            prompt = conv.get_prompt()
-            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(
-                0).to(self.device)
+            input_ids = (
+                tokenizer_image_token(
+                    prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+                )
+                .unsqueeze(0)
+                .to(self.device)
+            )
            labels = input_ids.clone()
            # Context part no need to calculate for loss
            labels[0, : contxt_id.shape[1]] = -100
            with torch.inference_mode():
-                outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True)
+                outputs = self.model(
+                    input_ids=input_ids, labels=labels, images=image, use_cache=True
+                )
            loss = outputs["loss"]
            # loss = torch.exp(loss)
            logits = outputs["logits"]
            greedy_tokens = logits.argmax(dim=-1)
-            cont_toks = input_ids[:, contxt_id.shape[1]:]  # [1, seq]
-            greedy_tokens = greedy_tokens[:, contxt_id.shape[1]: input_ids.shape[1]]  # [1, seq]
+            cont_toks = input_ids[:, contxt_id.shape[1] :]  # [1, seq]
+            greedy_tokens = greedy_tokens[
+                :, contxt_id.shape[1] : input_ids.shape[1]
+            ]  # [1, seq]
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(loss.item()), bool(max_equal)))
            pbar.update(1)
        pbar.close()
        return res

-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+    def loglikelihood_rolling(
+        self, requests: List[Instance]
+    ) -> List[Tuple[float, bool]]:
        raise NotImplementedError()

    def flatten(self, input):
@@ -265,8 +315,11 @@ class Llava(LM):
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=self.batch_size)
-        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(
-            requests) // self.batch_size + 1
+        num_iters = (
+            len(requests) // self.batch_size
+            if len(requests) % self.batch_size == 0
+            else len(requests) // self.batch_size + 1
+        )
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc, task = zip(*chunk)
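The `num_iters` expression above is just a ceiling division of the request count by the batch size; an equivalent shorter form (not what the commit uses) would be:

    import math

    num_iters = math.ceil(num_requests / batch_size)  # num_requests = len(requests)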
@@ -288,19 +341,32 @@ class Llava(LM):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
-                        f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
+                        f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}"
+                    )

-            if "image_aspect_ratio" in gen_kwargs.keys() and "image_aspect_ratio" not in self._config.__dict__:
+            if (
+                "image_aspect_ratio" in gen_kwargs.keys()
+                and "image_aspect_ratio" not in self._config.__dict__
+            ):
                # here we should pop it out of gen_kwargs so that it doesn't get passed to the model for next step of generation
                self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio")
-                eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")
+                eval_logger.info(
+                    f"Setting image aspect ratio: {self._config.image_aspect_ratio}"
+                )

            # encode, pad, and truncate contexts for this batch
            if visuals:
-                image_tensor = process_images(visuals, self._image_processor, self._config)
+                image_tensor = process_images(
+                    visuals, self._image_processor, self._config
+                )
                if isinstance(image_tensor, list):
-                    image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
+                    image_tensor = [
+                        _image.to(dtype=torch.float16, device=self.device)
+                        for _image in image_tensor
+                    ]
                else:
-                    image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
+                    image_tensor = image_tensor.to(
+                        dtype=torch.float16, device=self.device
+                    )
            else:
                image_tensor = None
@@ -309,15 +375,22 @@ class Llava(LM):
            question_input = []

            for visual, context in zip(visuals, contexts):
-                if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
+                if (
+                    image_tensor is not None
+                    and len(image_tensor) != 0
+                    and DEFAULT_IMAGE_TOKEN not in context
+                ):
                    """
                    Three senarios:
                    1. No image, and there for, no image token should be added.
                    2. image token is already specified in the context, so we don't need to add it.
                    3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
                    """
-                    image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [
-                        DEFAULT_IMAGE_TOKEN]
+                    image_tokens = (
+                        [DEFAULT_IMAGE_TOKEN] * len(visual)
+                        if isinstance(visual, list)
+                        else [DEFAULT_IMAGE_TOKEN]
+                    )
                    image_tokens = " ".join(image_tokens)
                    question = image_tokens + "\n" + context
                else:
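Per the "Three senarios" comment above, an image token is prepended only when images are present and the context does not already contain one, roughly:

    DEFAULT_IMAGE_TOKEN = "<image>"  # value assumed from LLaVA's constants
    visual = ["img_1", "img_2"]      # stand-ins for two PIL images
    context = "Describe the difference between the two pictures."
    image_tokens = " ".join([DEFAULT_IMAGE_TOKEN] * len(visual))
    question = image_tokens + "\n" + context
    # '<image> <image>\nDescribe the difference between the two pictures.'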
@@ -343,7 +416,9 @@ class Llava(LM):
            # input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
            # preconfigure gen_kwargs with defaults
-            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+            gen_kwargs["image_sizes"] = [
+                visuals[idx].size for idx in range(len(visuals))
+            ]
            if "max_gen_toks" not in gen_kwargs:
                gen_kwargs["max_gen_toks"] = 1024
            if "temperature" not in gen_kwargs:
@@ -353,10 +428,20 @@ class Llava(LM):
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1

-            input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for
-                              prompt in question_input]
-            pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-            input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device)
+            input_ids_list = [
+                tokenizer_image_token(
+                    prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+                )
+                for prompt in question_input
+            ]
+            pad_token_ids = (
+                self.tokenizer.pad_token_id
+                if self.tokenizer.pad_token_id is not None
+                else self.tokenizer.eos_token_id
+            )
+            input_ids = self.pad_sequence(
+                input_ids_list, batch_first=True, padding_value=pad_token_ids
+            ).to(self.device)
            attention_masks = input_ids.ne(pad_token_ids).to(self.device)
            # These steps are not in LLaVA's original code, but are necessary for generation to work
            # TODO: pay attention to this major generation step...
@@ -374,7 +459,9 @@ class Llava(LM):
                    max_new_tokens=gen_kwargs["max_gen_toks"],
                    use_cache=self.use_cache,
                )
-                text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
+                text_outputs = self.tokenizer.batch_decode(
+                    cont, skip_special_tokens=True
+                )
            except Exception as e:
                eval_logger.error(f"Error {e} in generating")
                cont = ""
@@ -393,10 +480,12 @@ class Llava(LM):
            #     # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
            #     text_outputs = text_outputs.split(term)[0]
            res.extend(text_outputs)
-            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
+            self.cache_hook.add_partial(
+                "generate_until", (context, gen_kwargs), text_outputs
+            )
            pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res
\ No newline at end of file
@@ -24,7 +24,12 @@ def replace_images_tokens(input_string):

def parse_options(options):
    option_letters = [chr(ord("A") + i) for i in range(len(options))]
-    choices_str = "\n".join([f"{option_letter}. {option}" for option_letter, option in zip(option_letters, options)])
+    choices_str = "\n".join(
+        [
+            f"{option_letter}. {option}"
+            for option_letter, option in zip(option_letters, options)
+        ]
+    )
    return choices_str
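`parse_options` simply letters the candidate answers, for example:

    parse_options(["cat", "dog", "bird"])
    # 'A. cat\nB. dog\nC. bird'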
@@ -49,7 +54,9 @@ def mmmu_doc_to_visual(doc):
    prompt = construct_prompt(doc)
    image_tokens = re.findall(r"<image \d+>", prompt)
    # Remove <> and swap space as _
-    image_tokens = [image_token.strip("<>").replace(" ", "_") for image_token in image_tokens]
+    image_tokens = [
+        image_token.strip("<>").replace(" ", "_") for image_token in image_tokens
+    ]
    visual = [doc[image_token].convert("RGB") for image_token in image_tokens]
    return visual
@@ -62,10 +69,14 @@ def mmmu_process_results(doc, results):
    else:
        parsed_pred = parse_open_response(pred)
    id = doc["id"]
-    mmmu_acc = {"id": id, "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_pred}
-    return {
-        "mmmu_acc": mmmu_acc
-    }
+    mmmu_acc = {
+        "id": id,
+        "subdomain": extract_subset_name(doc["id"]),
+        "question_type": doc["question_type"],
+        "answer": doc["answer"],
+        "parsed_pred": parsed_pred,
+    }
+    return {"mmmu_acc": mmmu_acc}


def extract_subset_name(input_string):
@@ -97,7 +108,12 @@ def mmmu_aggregate_results(results):
            else:
                pass
        in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
-        in_domain_data_num = sum([cat_results["num_example"] for cat_results in in_domain_cat_results.values()])
+        in_domain_data_num = sum(
+            [
+                cat_results["num_example"]
+                for cat_results in in_domain_cat_results.values()
+            ]
+        )
        printable_results["Overall-" + domain] = {
            "num": int(in_domain_data_num),
            "mmmu_acc": round(in_domain_ins_acc, 3),
@@ -110,7 +126,9 @@ def mmmu_aggregate_results(results):
    }
    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results["Overall"] = {
-        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
+        "num": sum(
+            [cat_results["num_example"] for cat_results in evaluation_result.values()]
+        ),
        "mmmu_acc": round(all_ins_acc, 3),
    }
    print(printable_results)
@@ -401,7 +419,9 @@ def parse_open_response(response):
            if not shortest_key_response:
                shortest_key_response = resp.split(indicator)[-1].strip()
            else:
-                if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
+                if len(resp.split(indicator)[-1].strip()) < len(
+                    shortest_key_response
+                ):
                    shortest_key_response = resp.split(indicator)[-1].strip()
        # key_responses.append(resp.split(indicator)[1].strip())
@@ -454,4 +474,4 @@ def get_multi_choice_info(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))

    return index2ans, all_choices
\ No newline at end of file
@@ -366,7 +366,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
                se = "%.4f" % se
                if isinstance(v, dict):
                    for v_key, v_v in v.items():
-                        values.append([k, version, f, n, m + "_" + v_key, "%.4f" % v_v, "±", se])
+                        values.append(
+                            [k, version, f, n, m + "_" + v_key, "%.4f" % v_v, "±", se]
+                        )
                else:
                    values.append([k, version, f, n, m, hib, "%.4f" % v, "±", se])
            else:
@@ -490,6 +492,7 @@ def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
    """
    return islice(raw_iterator, rank, limit, world_size)
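`create_iterator` shards and truncates a stream purely with `islice`: process `rank` takes every `world_size`-th element starting at its own offset, stopping at `limit`. A quick check of that slicing behaviour:

    from itertools import islice

    list(islice(range(20), 1, 12, 4))  # rank=1, limit=12, world_size=4 -> [1, 5, 9]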

class Collator:
    """
    A class for reordering and batching elements of an array.
@@ -514,7 +517,9 @@ class Collator:
        self.group_by_index()

    def group_by_index(self) -> None:
-        self.arr_with_indices = self.group(self.arr_with_indices, fn=self.group_fn, values=False)
+        self.arr_with_indices = self.group(
+            self.arr_with_indices, fn=self.group_fn, values=False
+        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
@@ -597,7 +602,9 @@ class Collator:
                hashable_dict = tuple(
                    (
                        key,
-                        tuple(value) if isinstance(value, collections.abc.Iterable) else value,
+                        tuple(value)
+                        if isinstance(value, collections.abc.Iterable)
+                        else value,
                    )
                    for key, value in sorted(fn(ob).items())
                )
@@ -645,4 +652,4 @@ class Collator:
                        arr = []
            if arr:
                yield arr
\ No newline at end of file
@@ -66,7 +66,7 @@ ifeval = ["langdetect", "immutabledict"]
neuronx = ["optimum[neuronx]"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
-mllm = ["transformers >= 4.40.0", "llava-torch == 1.0 @ git+https://github.com/haotian-liu/LLaVA.git"]
+mllm = ["transformers>=4.40.0", "llava-torch==1.1.1"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
openai = ["openai==1.3.9", "tiktoken"]
optimum = ["optimum[openvino]"]
...
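With the change above, the multimodal extra depends on the published `llava-torch` wheel instead of a direct git reference, so it can be installed like any other optional dependency group, e.g. `pip install -e ".[mllm]"` from a source checkout (exact package layout assumed).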