"src/vscode:/vscode.git/clone" did not exist on "3b3fcc16ceb3305800c731e5cd65637b3b132a65"
Commit 55f8321d authored by haileyschoelkopf

add up-to-date code (bs>1 seems to work w/ Llava-1.6-mistral: 0.34 MMMU)

parent 6e3b2ea1
......@@ -1346,7 +1346,6 @@ class ConfigurableTask(Task):
deepcopy(self.config.generation_kwargs),
self.doc_to_visual,
doc,
self.config.task,
)
elif self.INPUT_TYPE == "text":
arguments = (ctx, deepcopy(self.config.generation_kwargs))
......
import copy
from typing import List, Optional, Tuple, Union
import transformers
from tqdm import tqdm
from transformers import AutoModelForVision2Seq
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import Collator, stop_sequences_criteria
DEFAULT_IMAGE_TOKEN = "<image>"
......@@ -19,11 +19,11 @@ class HFMultimodalLM(HFLM):
An abstracted Hugging Face model class for multimodal LMs like Llava and Idefics.
"""
AUTO_MODEL_CLASS = AutoModelForVision2Seq
AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq
@property
def max_length(self):
raise NotImplementedError
# @property
# def max_length(self):
# raise NotImplementedError
@property
def tokenizer_name(self) -> str:
......@@ -194,6 +194,37 @@ class HFMultimodalLM(HFLM):
    # def tok_decode(self, tokens, skip_special_tokens=True):
    #     return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def _model_generate(self, inputs, stop, **gen_kwargs):
        # TODO: handle max_length
        # gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
        if "max_new_tokens" not in gen_kwargs:
            gen_kwargs["max_new_tokens"] = 1024
        if "temperature" not in gen_kwargs:
            gen_kwargs["temperature"] = 0
        # if "top_p" not in gen_kwargs:
        #     gen_kwargs["top_p"] = None
        # if "num_beams" not in gen_kwargs:
        #     gen_kwargs["num_beams"] = 1
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer,
            stop,
            inputs["input_ids"].shape[1],
            inputs["input_ids"].shape[0],
        )
        return self.model.generate(
            **inputs,
            # max_length=max_length,
            stopping_criteria=stopping_criteria,
            do_sample=True if gen_kwargs["temperature"] > 0 else False,
            temperature=gen_kwargs["temperature"],
            top_p=gen_kwargs["top_p"],
            num_beams=gen_kwargs["num_beams"],
            max_new_tokens=gen_kwargs["max_new_tokens"],
            use_cache=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        raise NotImplementedError(
            "model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks"
......@@ -204,14 +235,9 @@ class HFMultimodalLM(HFLM):
"model type `hf-multimodal` does not support loglikelihood or multiple choice. Use 'hf' model type for text-only loglikelihood tasks"
)
def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list
def generate_until(self, requests: List[Instance]) -> List[str]:
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
res = []
def _collate(x):
......@@ -224,47 +250,70 @@ class HFMultimodalLM(HFLM):
toks = self.tok_encode(x[0])
return -len(toks), x[0]
pbar = tqdm(
total=len(requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests with text+image input",
)
# TODO: port auto-batch sizing into this.
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator(
[reg.args for reg in requests], _collate, grouping=True
re_ords = Collator(
[reg.args for reg in requests],
_collate,
group_by="gen_kwargs",
group_fn=lambda x: x[1],
)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
num_iters = (
len(requests) // self.batch_size
if len(requests) % self.batch_size == 0
else len(requests) // self.batch_size + 1
)
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
### Up to here: was identical to non-multimodal HFLM generate_until ###
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc, task = zip(
contexts, all_gen_kwargs, doc_to_visual, doc = zip(
*chunk
) # TODO: understand what is going on here. can we cut down on number of distinct things we pass around?
task = task[0]
# split = split[0]
visuals = [vis(d) for vis, d in zip(doc_to_visual, doc)]
# visuals = self.flatten(visuals)
) # TODO: can we cut down further on number of distinct things we pass around?
visuals = [
vis(d) for vis, d in zip(doc_to_visual, doc)
] # TODO: I think *fully* flattening is just wrong for bs>1 ?
### this part onward: same as HFLM ###
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
if not until:
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
### end stuff that's entirely copied verbatim from HFLM ###
# Set default values for until and max_new_tokens
until = [self.tok_decode(self.eot_token_id)]
# Update values from gen_kwargs if present
if "until" in gen_kwargs:
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}"
)
assert (
self.batch_size_per_gpu == 1
), "Do not support batch_size_per_gpu > 1 for now"
context = contexts[0]
max_ctx_len = self.max_length - max_gen_toks # noqa: F841 # TODO: this assumes we are using a causal LM. is that always valid? shouldn't be
# if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
print(f"Prompt:\n\n{contexts}\n")
......@@ -272,37 +321,24 @@ class HFMultimodalLM(HFLM):
self.tokenizer.padding_side = "left"
inputs = self.processor(
images=visuals, text=contexts, return_tensors="pt", padding=True
).to(self._device, self.model.dtype) # TODO:
# gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
try:
cont = self.model.generate(
**inputs,
do_sample=True if gen_kwargs["temperature"] > 0 else False,
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=True,
pad_token_id=self.tokenizer.eos_token_id,
)
except Exception as e:
print(f"Error {e} in generating")
cont = ""
).to(
self._device, self.model.dtype
) # TODO: factor out into a tok_batch_encode bit ; truncate from left using max_ctx_len
context_enc = inputs["input_ids"]
if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
cont = self._model_generate(inputs, stop=until, **gen_kwargs)
### essentially same as HFLM beyond this line!
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
# if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: # TODO: ensure this holds for VLMs
cont_toks = cont_toks[inputs["input_ids"].shape[1] :]
cont_toks = cont_toks[context_enc.shape[1] :]
s = self.tok_decode(cont_toks)
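For context, the processor → generate → slice-off-the-prompt flow implemented above can be reproduced with plain `transformers`. A minimal standalone sketch, assuming a Llava-1.6-mistral checkpoint, placeholder image files, and that model's `[INST] ... [/INST]` prompt template; this is not the harness code.

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"  # assumed checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# left-pad so the prompt tokens can be sliced off uniformly after generation
processor.tokenizer.padding_side = "left"

images = [Image.open("cat.png"), Image.open("chart.png")]  # placeholder files
prompts = [
    "[INST] <image>\nWhat animal is shown? [/INST]",
    "[INST] <image>\nSummarize this chart. [/INST]",
]

inputs = processor(
    images=images, text=prompts, return_tensors="pt", padding=True
).to(model.device, model.dtype)

out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=64,
    pad_token_id=processor.tokenizer.eos_token_id,
)

# discard the (left-padded) prompt tokens, keeping only the continuations
cont_toks = out[:, inputs["input_ids"].shape[1] :]
print(processor.batch_decode(cont_toks, skip_special_tokens=True))
```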
......@@ -313,22 +349,12 @@ class HFMultimodalLM(HFLM):
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
if "1.5" in self.pretrained:
text_outputs = s.split("ASSISTANT:")[-1].strip()
elif "mistral" in self.pretrained:
text_outputs = s.split("[/INST]")[-1].strip()
else:
text_outputs = s.split("ASSISTANT:")[-1].strip()
# if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
print("hi hi")
print(f"Generated text:\n\n{text_outputs}\n")
res.append(text_outputs)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), text_outputs
)
pbar.update(1)
print(f"Generated text:\n\n{s}\n")
res.append(s)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)
......
......@@ -10,7 +10,7 @@ import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, List
import numpy as np
import yaml
......@@ -491,165 +491,3 @@ def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
among ranks in multigpu setting or only pulling a sample of documents
"""
return islice(raw_iterator, rank, limit, world_size)
class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable,
        group_fn: Callable = lambda x: x[1],
        grouping: bool = False,
    ) -> None:
        self.grouping = grouping
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
        self.reorder_indices: List = []
        self.size = len(arr)
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
        if self.grouping is True:
            self.group_by_index()

    def group_by_index(self) -> None:
        self.arr_with_indices = self.group(
            self.arr_with_indices, fn=self.group_fn, values=False
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.

        Yields:
        Iterator: An iterator over batches of reordered elements.
        """
        if self.grouping:
            for (
                key,
                values,
            ) in self.arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        List: Yields reordered elements one by one.
        """
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (List): The reordered array.

        Returns:
        List: The array with elements restored to their original order.
        """
        res = [None] * self.size
        cov = [False] * self.size

        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self.size

    @staticmethod
    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """
        Groups elements of an iterable based on a provided function.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - values (bool): If True, returns the values of the group. Defaults to False.

        Returns:
        Iterable: An iterable of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                hashable_dict = tuple(
                    (
                        key,
                        tuple(value)
                        if isinstance(value, collections.abc.Iterable)
                        else value,
                    )
                    for key, value in sorted(fn(ob).items())
                )
                res[hashable_dict].append(ob)
            except TypeError:
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
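A brief usage sketch of the Collator defined above (this file's version, whose constructor takes `sort_fn`/`group_fn`/`grouping`). The request strings are illustrative and the generation step is stubbed out; it assumes the class above is in scope.

```python
requests = [
    ("What is 2+2?", {"until": ["\n"], "temperature": 0}),
    ("Summarize the following very long passage ...", {"until": ["\n"], "temperature": 0}),
    ("Write a haiku.", {"temperature": 0.8}),
]

# sort longest-context-first within each group, and group requests that share
# gen_kwargs so greedy and sampled generation never land in the same batch
collator = Collator(
    requests,
    sort_fn=lambda x: -len(x[0]),
    group_fn=lambda x: x[1],  # group key: the gen_kwargs dict (made hashable internally)
    grouping=True,
)

outputs = []
for batch in collator.get_batched(n=2):
    # each batch is a list of (context, gen_kwargs) pairs with identical gen_kwargs
    outputs.extend(f"<generated for: {ctx!r}>" for ctx, _ in batch)

# restore the original request ordering before returning results
outputs = collator.get_original(outputs)
print(outputs)
```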