Commit 9a4fa699 authored by Baber's avatar Baber
Browse files

fix logprobs

parent b1746639
......@@ -44,11 +44,16 @@ class HFMultimodalLM(HFLM):
interleave: bool = True,
# TODO: handle whitespace in image placeholder (replacement)
max_images: Optional[int] = 999,
convert_img_format=False,
convert_img_format: bool = False,
auto_model_class: str = None,
**kwargs,
):
if auto_model_class is not None:
self.AUTO_MODEL_CLASS = getattr(transformers, auto_model_class)
# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
# modify init behavior.
super().__init__(pretrained, **kwargs)
assert (
......@@ -183,7 +188,9 @@ class HFMultimodalLM(HFLM):
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
# TODO: replace default <image> placeholder with self.image_token, for contexts
context = replace_placeholders(
context, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
)
whole_enc, image_enc = self.tok_multimodal_encode(
context + continuation, images
......@@ -339,7 +346,7 @@ class HFMultimodalLM(HFLM):
"""
# note: imgs is a dict.
with torch.no_grad():
return self.model(inps, **imgs).logits
return self.model(inps, attention_mask=torch.ones_like(inps), **imgs).logits
def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
......@@ -456,8 +463,7 @@ class HFMultimodalLM(HFLM):
requests,
sort_fn=_collate,
group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
and self.logits_cache
if self.backend == "causal" and self.logits_cache
else None,
group_fn=_lookup_one_token_cont,
)
......@@ -563,7 +569,7 @@ class HFMultimodalLM(HFLM):
# from prompt/prefix tuning tokens, if applicable
ctx_len = (
inplen + (logits.shape[0] - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
if self.backend == "causal"
else None
)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
......
include: mathvista.yaml
dataset_path: AI4Math/MathVista
task: mathvista_mcq
test_split: testmini
output_type: "multiple_choice"
doc_to_image:
- decoded_image
doc_to_text: "<image>{{query}}\n\nAnswer:"
process_docs: !function utils.process_docs_mcq
doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices.length] }}'
doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F", "G"][:choices|length] }}'
doc_to_target: "{{choices.index(answer)}}"
metric_list:
- metric: acc
......@@ -11,5 +15,7 @@ metric_list:
- metric: acc_norm
aggregation: mean
higher_is_better: true
dataset_kwargs:
trust_remote_code: true
metadata:
version: 1.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment