Commit 9a4fa699 authored by Baber's avatar Baber
Browse files

fix logprobs

parent b1746639
...@@ -44,11 +44,16 @@ class HFMultimodalLM(HFLM): ...@@ -44,11 +44,16 @@ class HFMultimodalLM(HFLM):
interleave: bool = True, interleave: bool = True,
# TODO: handle whitespace in image placeholder (replacement) # TODO: handle whitespace in image placeholder (replacement)
max_images: Optional[int] = 999, max_images: Optional[int] = 999,
convert_img_format=False, convert_img_format: bool = False,
auto_model_class: str = None,
**kwargs, **kwargs,
): ):
if auto_model_class is not None:
self.AUTO_MODEL_CLASS = getattr(transformers, auto_model_class)
# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer # We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
# modify init behavior. # modify init behavior.
super().__init__(pretrained, **kwargs) super().__init__(pretrained, **kwargs)
assert ( assert (
...@@ -183,7 +188,9 @@ class HFMultimodalLM(HFLM): ...@@ -183,7 +188,9 @@ class HFMultimodalLM(HFLM):
continuation = context[-n_spaces:] + continuation continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces] context = context[:-n_spaces]
# TODO: replace default <image> placeholder with self.image_token, for contexts context = replace_placeholders(
context, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
)
whole_enc, image_enc = self.tok_multimodal_encode( whole_enc, image_enc = self.tok_multimodal_encode(
context + continuation, images context + continuation, images
...@@ -339,7 +346,7 @@ class HFMultimodalLM(HFLM): ...@@ -339,7 +346,7 @@ class HFMultimodalLM(HFLM):
""" """
# note: imgs is a dict. # note: imgs is a dict.
with torch.no_grad(): with torch.no_grad():
return self.model(inps, **imgs).logits return self.model(inps, attention_mask=torch.ones_like(inps), **imgs).logits
def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs): def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
...@@ -456,8 +463,7 @@ class HFMultimodalLM(HFLM): ...@@ -456,8 +463,7 @@ class HFMultimodalLM(HFLM):
requests, requests,
sort_fn=_collate, sort_fn=_collate,
group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM if self.backend == "causal" and self.logits_cache
and self.logits_cache
else None, else None,
group_fn=_lookup_one_token_cont, group_fn=_lookup_one_token_cont,
) )
...@@ -563,7 +569,7 @@ class HFMultimodalLM(HFLM): ...@@ -563,7 +569,7 @@ class HFMultimodalLM(HFLM):
# from prompt/prefix tuning tokens, if applicable # from prompt/prefix tuning tokens, if applicable
ctx_len = ( ctx_len = (
inplen + (logits.shape[0] - padding_len_inp) inplen + (logits.shape[0] - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM if self.backend == "causal"
else None else None
) )
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
......
include: mathvista.yaml dataset_path: AI4Math/MathVista
task: mathvista_mcq task: mathvista_mcq
test_split: testmini
output_type: "multiple_choice" output_type: "multiple_choice"
doc_to_image:
- decoded_image
doc_to_text: "<image>{{query}}\n\nAnswer:"
process_docs: !function utils.process_docs_mcq process_docs: !function utils.process_docs_mcq
doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices.length] }}' doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F", "G"][:choices|length] }}'
doc_to_target: "{{choices.index(answer)}}" doc_to_target: "{{choices.index(answer)}}"
metric_list: metric_list:
- metric: acc - metric: acc
...@@ -11,5 +15,7 @@ metric_list: ...@@ -11,5 +15,7 @@ metric_list:
- metric: acc_norm - metric: acc_norm
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
dataset_kwargs:
trust_remote_code: true
metadata: metadata:
version: 1.0 version: 1.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment