Unverified Commit 82a99365 authored by Yury Sulsky's avatar Yury Sulsky Committed by GitHub
Browse files

Enable text-only evals for VLM models (#2999)

parent 9d29ef0e
...@@ -494,10 +494,6 @@ def evaluate( ...@@ -494,10 +494,6 @@ def evaluate(
raise ValueError( raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type." f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
) )
else:
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
)
# end validation check # end validation check
# Cache the limit arg. # Cache the limit arg.
......
...@@ -399,6 +399,9 @@ class HFMultimodalLM(HFLM): ...@@ -399,6 +399,9 @@ class HFMultimodalLM(HFLM):
return batched_imgs return batched_imgs
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
    """Rolling loglikelihood for `hf-multimodal`.

    Text-only requests (fewer than 3 entries in ``request.args``, i.e. no
    auxiliary image argument) are delegated to the plain ``HFLM``
    implementation; multimodal rolling loglikelihood is unsupported.

    Args:
        requests: evaluation requests; only the first request's arity is
            inspected to decide text-only vs. multimodal (assumes a task
            is homogeneous — TODO confirm upstream guarantees this).

    Returns:
        Per-request rolling loglikelihoods (from the ``HFLM`` fallback).

    Raises:
        NotImplementedError: if the requests carry multimodal arguments.
    """
    if requests and len(requests[0].args) < 3:
        # Fall back to non-multimodal generation.
        return super().loglikelihood_rolling(requests=requests)
    # BUG FIX: the two message strings were previously passed as separate
    # positional arguments to NotImplementedError, so str(exc) rendered as
    # a tuple. Implicit literal concatenation yields one coherent message.
    raise NotImplementedError(
        "model type `hf-multimodal` does not support loglikelihood_rolling. "
        "Use 'hf' model type for text-only loglikelihood_rolling tasks; "
        "this is because we do not support measuring the loglikelihood a model assigns to an image."
    )
...@@ -407,6 +410,9 @@ class HFMultimodalLM(HFLM): ...@@ -407,6 +410,9 @@ class HFMultimodalLM(HFLM):
def loglikelihood(
    self, requests: List[Instance], disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
    """Loglikelihood for `hf-multimodal`: text-only requests only.

    Requests without an auxiliary (image) argument — i.e. fewer than 3
    entries in ``request.args`` — are handled by the plain ``HFLM``
    implementation; multimodal loglikelihood is not yet supported.

    Raises:
        NotImplementedError: when the requests carry multimodal arguments.
    """
    # Only the first request's arity is inspected; an empty request list is
    # treated as multimodal and rejected, matching the original behavior.
    is_text_only = bool(requests) and len(requests[0].args) < 3
    if not is_text_only:
        raise NotImplementedError(
            "'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
        )
    # Delegate text-only work to the non-multimodal superclass path.
    return super().loglikelihood(requests=requests, disable_tqdm=disable_tqdm)
...@@ -433,9 +439,11 @@ class HFMultimodalLM(HFLM): ...@@ -433,9 +439,11 @@ class HFMultimodalLM(HFLM):
) )
) )
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm) return self._multimodal_loglikelihood_tokens(
new_reqs, disable_tqdm=disable_tqdm
)
def _loglikelihood_tokens( def _multimodal_loglikelihood_tokens(
self, self,
requests: List[ requests: List[
Tuple[Tuple[None, str, str], List[int], List[int], List[int]] Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
...@@ -624,7 +632,10 @@ class HFMultimodalLM(HFLM): ...@@ -624,7 +632,10 @@ class HFMultimodalLM(HFLM):
def generate_until( def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]: ) -> List[str]:
# TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs) if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
res = [] res = []
def _collate(x): def _collate(x):
......
...@@ -890,7 +890,10 @@ class HFLM(TemplateLM): ...@@ -890,7 +890,10 @@ class HFLM(TemplateLM):
input_ids=inps, attention_mask=attn_mask, labels=labels input_ids=inps, attention_mask=attn_mask, labels=labels
).logits ).logits
else: else:
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM assert self.AUTO_MODEL_CLASS in (
transformers.AutoModelForCausalLM,
transformers.AutoModelForVision2Seq,
)
return self.model(inps).logits return self.model(inps).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs): def _model_generate(self, context, max_length, stop, **generation_kwargs):
......
...@@ -106,7 +106,7 @@ class VLLM_VLM(VLLM): ...@@ -106,7 +106,7 @@ class VLLM_VLM(VLLM):
outputs.append(inputs) outputs.append(inputs)
return outputs return outputs
def _model_generate( def _multimodal_model_generate(
self, self,
requests: List[List[dict]] = None, requests: List[List[dict]] = None,
generate: bool = False, generate: bool = False,
...@@ -218,7 +218,10 @@ class VLLM_VLM(VLLM): ...@@ -218,7 +218,10 @@ class VLLM_VLM(VLLM):
def generate_until( def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]: ) -> List[str]:
# TODO: support text-only reqs if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
res = [] res = []
def _collate(x): def _collate(x):
...@@ -293,7 +296,7 @@ class VLLM_VLM(VLLM): ...@@ -293,7 +296,7 @@ class VLLM_VLM(VLLM):
left_truncate_len=max_ctx_len, left_truncate_len=max_ctx_len,
) )
cont = self._model_generate( cont = self._multimodal_model_generate(
inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
) )
...@@ -309,3 +312,12 @@ class VLLM_VLM(VLLM): ...@@ -309,3 +312,12 @@ class VLLM_VLM(VLLM):
pbar.close() pbar.close()
return res return res
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
    """Rolling loglikelihood for `vllm-vlm`.

    Text-only requests (fewer than 3 entries in ``request.args``, i.e. no
    auxiliary image argument) fall back to the text-only ``VLLM``
    implementation; multimodal rolling loglikelihood is unsupported.

    Args:
        requests: evaluation requests; only the first request's arity is
            inspected to decide text-only vs. multimodal (assumes a task
            is homogeneous — TODO confirm upstream guarantees this).

    Returns:
        Per-request rolling loglikelihoods (from the ``VLLM`` fallback).

    Raises:
        NotImplementedError: if the requests carry multimodal arguments.
    """
    if requests and len(requests[0].args) < 3:
        # Fall back to non-multimodal generation.
        return super().loglikelihood_rolling(requests=requests)
    # BUG FIX (two issues): (1) the two message strings were passed as
    # separate positional arguments to NotImplementedError, so str(exc)
    # rendered as a tuple — join them; (2) the message pointed users at a
    # nonexistent 'vlm' model type — the text-only vLLM backend is 'vllm'.
    raise NotImplementedError(
        "model type `vllm-vlm` does not support loglikelihood_rolling. "
        "Use 'vllm' model type for text-only loglikelihood_rolling tasks; "
        "this is because we do not support measuring the loglikelihood a model assigns to an image."
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment