Commit 55f8321d authored by haileyschoelkopf

add up-to-date code (bs>1 seems to work w/ Llava-1.6-mistral: 0.34 MMMU)

parent 6e3b2ea1
@@ -1346,7 +1346,6 @@ class ConfigurableTask(Task):
                     deepcopy(self.config.generation_kwargs),
                     self.doc_to_visual,
                     doc,
-                    self.config.task,
                 )
             elif self.INPUT_TYPE == "text":
                 arguments = (ctx, deepcopy(self.config.generation_kwargs))
...
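With `self.config.task` dropped from the tuple above, a multimodal `generate_until` request carries four fields, matching the `contexts, all_gen_kwargs, doc_to_visual, doc = zip(*chunk)` unpacking in the model code below. A minimal illustrative sketch (the helper name and the stub `doc_to_visual` are hypothetical, not part of the diff):

```
from copy import deepcopy

# Hypothetical helper, for illustration only: the 4-tuple of per-document
# arguments that ConfigurableTask builds for image+text generation requests.
def build_multimodal_arguments(ctx, generation_kwargs, doc_to_visual, doc):
    # Later unpacked in HFMultimodalLM.generate_until as:
    #   contexts, all_gen_kwargs, doc_to_visual, doc = zip(*chunk)
    return (ctx, deepcopy(generation_kwargs), doc_to_visual, doc)

# Example: doc_to_visual maps a doc dict to its images (stubbed out here).
args = build_multimodal_arguments(
    "Question: what is shown? Answer:",
    {"until": ["\n\n"], "max_new_tokens": 32},
    lambda d: d["images"],
    {"images": []},
)
```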
+import copy
 from typing import List, Optional, Tuple, Union
 import transformers
 from tqdm import tqdm
-from transformers import AutoModelForVision2Seq
-from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
 from lm_eval.models.huggingface import HFLM
+from lm_eval.models.utils import Collator, stop_sequences_criteria
 DEFAULT_IMAGE_TOKEN = "<image>"
@@ -19,11 +19,11 @@ class HFMultimodalLM(HFLM):
     An abstracted Hugging Face model class for multimodal LMs like Llava and Idefics.
     """
-    AUTO_MODEL_CLASS = AutoModelForVision2Seq
+    AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq
-    @property
-    def max_length(self):
-        raise NotImplementedError
+    # @property
+    # def max_length(self):
+    #     raise NotImplementedError
     @property
     def tokenizer_name(self) -> str:
@@ -194,6 +194,37 @@ class HFMultimodalLM(HFLM):
     # def tok_decode(self, tokens, skip_special_tokens=True):
     #     return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
+    def _model_generate(self, inputs, stop, **gen_kwargs):
+        # TODO: handle max_length
+        # gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+        if "max_new_tokens" not in gen_kwargs:
+            gen_kwargs["max_new_tokens"] = 1024
+        if "temperature" not in gen_kwargs:
+            gen_kwargs["temperature"] = 0
+        # if "top_p" not in gen_kwargs:
+        #     gen_kwargs["top_p"] = None
+        # if "num_beams" not in gen_kwargs:
+        #     gen_kwargs["num_beams"] = 1
+        stopping_criteria = stop_sequences_criteria(
+            self.tokenizer,
+            stop,
+            inputs["input_ids"].shape[1],
+            inputs["input_ids"].shape[0],
+        )
+        return self.model.generate(
+            **inputs,
+            # max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=True if gen_kwargs["temperature"] > 0 else False,
+            temperature=gen_kwargs["temperature"],
+            top_p=gen_kwargs["top_p"],
+            num_beams=gen_kwargs["num_beams"],
+            max_new_tokens=gen_kwargs["max_new_tokens"],
+            use_cache=True,
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         raise NotImplementedError(
             "model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks"
@@ -204,14 +235,9 @@ class HFMultimodalLM(HFLM):
             "model type `hf-multimodal` does not support loglikelihood or multiple choice. Use 'hf' model type for text-only loglikelihood tasks"
         )
-    def flatten(self, input):
-        new_list = []
-        for i in input:
-            for j in i:
-                new_list.append(j)
-        return new_list
-    def generate_until(self, requests: List[Instance]) -> List[str]:
+    def generate_until(
+        self, requests: List[Instance], disable_tqdm: bool = False
+    ) -> List[str]:
         res = []
         def _collate(x):
@@ -224,47 +250,70 @@ class HFMultimodalLM(HFLM):
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running generate_until requests with text+image input",
+        )
+        # TODO: port auto-batch sizing into this.
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = utils.Collator(
-            [reg.args for reg in requests], _collate, grouping=True
+        re_ords = Collator(
+            [reg.args for reg in requests],
+            _collate,
+            group_by="gen_kwargs",
+            group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
-        num_iters = (
-            len(requests) // self.batch_size
-            if len(requests) % self.batch_size == 0
-            else len(requests) // self.batch_size + 1
-        )
-        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
+        ### Up to here: was identical to non-multimodal HFLM generate_until ###
         for chunk in chunks:
-            contexts, all_gen_kwargs, doc_to_visual, doc, task = zip(
+            contexts, all_gen_kwargs, doc_to_visual, doc = zip(
                 *chunk
-            )  # TODO: understand what is going on here. can we cut down on number of distinct things we pass around?
-            task = task[0]
-            # split = split[0]
-            visuals = [vis(d) for vis, d in zip(doc_to_visual, doc)]
-            # visuals = self.flatten(visuals)
+            )  # TODO: can we cut down further on number of distinct things we pass around?
+            visuals = [
+                vis(d) for vis, d in zip(doc_to_visual, doc)
+            ]  # TODO: I think *fully* flattening is just wrong for bs>1 ?
+            ### this part onward: same as HFLM ###
             # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
-            # Set default values for until and max_new_tokens
-            until = [self.tok_decode(self.eot_token_id)]
-            # Update values from gen_kwargs if present
-            if "until" in gen_kwargs:
-                until = gen_kwargs.pop("until")
-                if isinstance(until, str):
-                    until = [until]
-                elif not isinstance(until, list):
-                    raise ValueError(
-                        f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}"
-                    )
-            assert (
-                self.batch_size_per_gpu == 1
-            ), "Do not support batch_size_per_gpu > 1 for now"
-            context = contexts[0]
+            # unpack our keyword arguments.
+            until = None
+            if isinstance(gen_kwargs, dict):
+                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                if "until" in kwargs.keys():
+                    until = kwargs.pop("until")
+                    if isinstance(until, str):
+                        until = [until]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                )
+            # add EOS token to stop sequences
+            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
+            if not until:
+                until = [eos]
+            else:
+                until.append(eos)
+            if "max_gen_toks" in kwargs.keys():
+                max_gen_toks = kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
+            ### end stuff that's entirely copied verbatim from HFLM ###
+            max_ctx_len = self.max_length - max_gen_toks  # noqa: F841  # TODO: this assumes we are using a causal LM. is that always valid? shouldn't be
             # if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
             print(f"Prompt:\n\n{contexts}\n")
@@ -272,37 +321,24 @@ class HFMultimodalLM(HFLM):
             self.tokenizer.padding_side = "left"
             inputs = self.processor(
                 images=visuals, text=contexts, return_tensors="pt", padding=True
-            ).to(self._device, self.model.dtype)  # TODO:
-            # gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
-            if "max_new_tokens" not in gen_kwargs:
-                gen_kwargs["max_new_tokens"] = 1024
-            if "temperature" not in gen_kwargs:
-                gen_kwargs["temperature"] = 0
-            if "top_p" not in gen_kwargs:
-                gen_kwargs["top_p"] = None
-            if "num_beams" not in gen_kwargs:
-                gen_kwargs["num_beams"] = 1
-            try:
-                cont = self.model.generate(
-                    **inputs,
-                    do_sample=True if gen_kwargs["temperature"] > 0 else False,
-                    temperature=gen_kwargs["temperature"],
-                    top_p=gen_kwargs["top_p"],
-                    num_beams=gen_kwargs["num_beams"],
-                    max_new_tokens=gen_kwargs["max_new_tokens"],
-                    use_cache=True,
-                    pad_token_id=self.tokenizer.eos_token_id,
-                )
-            except Exception as e:
-                print(f"Error {e} in generating")
-                cont = ""
+            ).to(
+                self._device, self.model.dtype
+            )  # TODO: factor out into a tok_batch_encode bit ; truncate from left using max_ctx_len
+            context_enc = inputs["input_ids"]
+            if "max_length" not in kwargs:
+                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
+            cont = self._model_generate(inputs, stop=until, **gen_kwargs)
+            ### essentially same as HFLM beyond this line!
             cont_toks_list = cont.tolist()
             for cont_toks, context in zip(cont_toks_list, contexts):
                 # discard context + left-padding toks if using causal decoder-only LM
                 # if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: # TODO: ensure this holds for VLMs
-                cont_toks = cont_toks[inputs["input_ids"].shape[1] :]
+                cont_toks = cont_toks[context_enc.shape[1] :]
                 s = self.tok_decode(cont_toks)
@@ -313,21 +349,11 @@ class HFMultimodalLM(HFLM):
                         # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                         s = s.split(term)[0]
-                if "1.5" in self.pretrained:
-                    text_outputs = s.split("ASSISTANT:")[-1].strip()
-                elif "mistral" in self.pretrained:
-                    text_outputs = s.split("[/INST]")[-1].strip()
-                else:
-                    text_outputs = s.split("ASSISTANT:")[-1].strip()
                 # if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
-                print("hi hi")
-                print(f"Generated text:\n\n{text_outputs}\n")
-                res.append(text_outputs)
-                self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), text_outputs
-                )
+                print(f"Generated text:\n\n{s}\n")
+                res.append(s)
+                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                 pbar.update(1)
         # reorder this group of results back to original unsorted form
         res = re_ords.get_original(res)
...
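The `until` / `max_gen_toks` unpacking that the rewritten `generate_until` copies from HFLM can be exercised on its own. Below is a minimal standalone sketch of that step; the function name and the default of 256 tokens are illustrative assumptions, not taken from the diff:

```
import copy
from typing import List, Tuple


def unpack_gen_kwargs(
    gen_kwargs: dict, eos: str, default_max_gen_toks: int = 256
) -> Tuple[dict, List[str], int]:
    """Mirror the HFLM-style unpacking used in generate_until: pop `until` and
    `max_gen_toks`, normalize `until` to a list, and always append the EOS text."""
    if not isinstance(gen_kwargs, dict):
        raise ValueError(
            f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
        )
    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
    until = kwargs.pop("until", None)
    if isinstance(until, str):
        until = [until]
    elif until is not None and not isinstance(until, list):
        raise ValueError(
            f"Expected `until` to be of type Union[str,list] but got {until}"
        )
    until = (until or []) + [eos]  # EOS always terminates generation
    max_gen_toks = kwargs.pop("max_gen_toks", default_max_gen_toks)
    return kwargs, until, max_gen_toks


# Example: the remaining kwargs feed _model_generate; `until` feeds the stopping criteria.
print(unpack_gen_kwargs({"until": "\n\n", "temperature": 0}, eos="</s>"))
# -> ({'temperature': 0}, ['\n\n', '</s>'], 256)
```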
@@ -10,7 +10,7 @@ import os
 import re
 from dataclasses import asdict, is_dataclass
 from itertools import islice
-from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, List
 import numpy as np
 import yaml
@@ -491,165 +491,3 @@ def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
     among ranks in multigpu setting or only pulling a sample of documents
     """
     return islice(raw_iterator, rank, limit, world_size)
-
-
-class Collator:
-    """
-    A class for reordering and batching elements of an array.
-    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
-    """
-
-    def __init__(
-        self,
-        arr: List,
-        sort_fn: Callable,
-        group_fn: Callable = lambda x: x[1],
-        grouping: bool = False,
-    ) -> None:
-        self.grouping = grouping
-        self.fn = sort_fn
-        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
-        self.reorder_indices: List = []
-        self.size = len(arr)
-        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
-        if self.grouping is True:
-            self.group_by_index()
-
-    def group_by_index(self) -> None:
-        self.arr_with_indices = self.group(
-            self.arr_with_indices, fn=self.group_fn, values=False
-        )
-
-    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
-        """
-        Generates and yields batches from the reordered array.
-
-        Parameters:
-        - n (int): The size of each batch. Defaults to 1.
-        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.
-
-        Yields:
-        Iterator: An iterator over batches of reordered elements.
-        """
-        if self.grouping:
-            for (
-                key,
-                values,
-            ) in self.arr_with_indices.items():  # type: ignore
-                values = self._reorder(values)
-                batch = self.get_chunks(values, n=n, fn=batch_fn)
-                yield from batch
-        else:
-            values = self._reorder(self.arr_with_indices)  # type: ignore
-            batch = self.get_chunks(values, n=n, fn=batch_fn)
-            yield from batch
-
-    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
-        """
-        Reorders the elements in the array based on the sorting function.
-
-        Parameters:
-        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.
-
-        Yields:
-        List: Yields reordered elements one by one.
-        """
-        arr = sorted(arr, key=lambda x: self.fn(x[1]))
-        self.reorder_indices.extend([x[0] for x in arr])
-        yield from [x[1] for x in arr]
-
-    def get_original(self, newarr: List) -> List:
-        """
-        Restores the original order of elements from the reordered list.
-
-        Parameters:
-        - newarr (List): The reordered array.
-
-        Returns:
-        List: The array with elements restored to their original order.
-        """
-        res = [None] * self.size
-        cov = [False] * self.size
-
-        for ind, v in zip(self.reorder_indices, newarr):
-            res[ind] = v
-            cov[ind] = True
-
-        assert all(cov)
-
-        return res
-
-    def __len__(self):
-        return self.size
-
-    @staticmethod
-    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
-        """
-        Groups elements of an iterable based on a provided function.
-
-        Parameters:
-        - arr (Iterable): The iterable to be grouped.
-        - fn (Callable): The function to determine the grouping.
-        - values (bool): If True, returns the values of the group. Defaults to False.
-
-        Returns:
-        Iterable: An iterable of grouped elements.
-        """
-        res = collections.defaultdict(list)
-        for ob in arr:
-            try:
-                hashable_dict = tuple(
-                    (
-                        key,
-                        tuple(value)
-                        if isinstance(value, collections.abc.Iterable)
-                        else value,
-                    )
-                    for key, value in sorted(fn(ob).items())
-                )
-                res[hashable_dict].append(ob)
-            except TypeError:
-                res[fn(ob)].append(ob)
-        if not values:
-            return res
-        return res.values()
-
-    @staticmethod
-    def get_chunks(_iter, n: int = 0, fn=None):
-        """
-        Divides an iterable into chunks of specified size or based on a given function.
-        Useful for batching
-
-        Parameters:
-        - iter: The input iterable to be divided into chunks.
-        - n: An integer representing the size of each chunk. Default is 0.
-        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
-
-        Returns:
-        An iterator that yields chunks of the input iterable.
-
-        Example usage:
-        ```
-        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        for chunk in chunks(data, 3):
-            print(chunk)
-        ```
-        Output:
-        ```
-        [1, 2, 3]
-        [4, 5, 6]
-        [7, 8, 9]
-        [10]
-        ```
-        """
-        arr = []
-        _iter = tuple(_iter)
-        for i, x in enumerate(_iter):
-            arr.append(x)
-            if len(arr) == (fn(i, _iter) if fn else n):
-                yield arr
-                arr = []
-
-        if arr:
-            yield arr
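For reference, a small usage sketch of the `Collator` being deleted here (it continues to live, in extended form, in `lm_eval.models.utils`). The requests, sort key, and batch size are made up for illustration and assume the class above is in scope:

```
# Illustrative only: each element is (context, gen_kwargs); requests are sorted
# by context length within a group, and grouped so batches never mix settings.
requests = [
    ("a fairly long prompt", {"temperature": 0.0}),
    ("short", {"temperature": 0.0}),
    ("another prompt", {"temperature": 0.8}),
]

collator = Collator(
    requests,
    sort_fn=lambda x: (-len(x[0]), x[0]),  # longest context first
    group_fn=lambda x: x[1],               # group on the gen_kwargs dict
    grouping=True,
)

batched = [batch for batch in collator.get_batched(n=2)]
# e.g. [[('a fairly long prompt', ...), ('short', ...)], [('another prompt', ...)]]

# Per-batch results can be restored to the original request order afterwards:
results = [len(ctx) for batch in batched for ctx, _ in batch]
print(collator.get_original(results))  # follows the order of `requests`
```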