Commit 3c772593 authored by Baber

add attn_mask (llava models need it)

parent 4142b7b2
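
Note (not part of the commit): a minimal standalone sketch of the motivation, with an assumed checkpoint id and toy inputs that are illustrative only. LLaVA-style models in `transformers` take an `attention_mask` alongside `input_ids` and `pixel_values`, and the processor produces all three together, so the harness should pass the mask through rather than omit it:

```python
# Illustrative sketch only; the checkpoint and inputs are assumptions, not part
# of this commit. It shows the mask the model's forward pass expects.
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # example checkpoint, assumed
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)

image = Image.new("RGB", (336, 336))  # toy black image
inputs = processor(text="<image>\nWhat is shown?", images=image, return_tensors="pt")
with torch.no_grad():
    # inputs holds input_ids, attention_mask, and pixel_values; passing the
    # mask explicitly mirrors what _model_multimodal_call now does with attn_mask.
    logits = model(**inputs).logits
```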
 import copy
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
@@ -21,6 +21,9 @@ from lm_eval.models.utils import (
 from lm_eval.utils import add_padding_if_needed

+if TYPE_CHECKING:
+    import PIL
+
 DEFAULT_IMAGE_PLACEHOLDER = "<image>"
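
Note (not part of the commit): the `TYPE_CHECKING` guard exists only to support the new `List["PIL.Image.Image"]` annotation in the next hunk. A minimal sketch of the idiom:

```python
# The import is visible to static type checkers (mypy/pyright) but skipped at
# runtime, so PIL is not a hard import-time dependency of the module.
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    import PIL

def count_images(images: List["PIL.Image.Image"]) -> int:
    # the string annotation is only resolved by type checkers, never executed
    return len(images)
```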
@@ -175,7 +178,9 @@ class HFMultimodalLM(HFLM):
         return text_encoding, encoding  # image_encoding is a dict

-    def _encode_multimodal_pair(self, context, continuation, images):
+    def _encode_multimodal_pair(
+        self, context, continuation, images: List["PIL.Image.Image"]
+    ):
         """Helper function to perform the role of TemplateLM._encode_pair
         Except allowing for image input to also be processed alongside `context`.
@@ -192,6 +197,9 @@ class HFMultimodalLM(HFLM):
             context, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
         )

+        if self.rgb:
+            images = [img.convert("RGB") for img in images]
+
         whole_enc, image_enc = self.tok_multimodal_encode(
             context + continuation, images
         )
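
Note (not part of the commit): the new `self.rgb` branch matters because many HF image processors assume 3-channel input, while datasets frequently yield palette, grayscale, or RGBA images. A standalone sketch using only PIL:

```python
# Non-RGB modes can fail or produce the wrong channel count once the image
# processor normalizes images to tensors; converting up front is cheap and safe.
from PIL import Image

images = [Image.new("RGBA", (224, 224)), Image.new("L", (224, 224))]
images = [img.convert("RGB") for img in images]  # same expression as the diff
assert all(img.mode == "RGB" for img in images)
```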
@@ -346,7 +354,7 @@ class HFMultimodalLM(HFLM):
         """
         # note: imgs is a dict.
         with torch.no_grad():
-            return self.model(inps, attention_mask=torch.ones_like(inps), **imgs).logits
+            return self.model(inps, **imgs, attention_mask=attn_mask).logits

     def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
@@ -384,7 +392,9 @@ class HFMultimodalLM(HFLM):
             batched_imgs[key] = torch.cat(
                 [
                     torch.tensor(
-                        image_enc[key], device=self.device, dtype=self.model.dtype
+                        image_enc[key],
+                        device=self.device,
+                        dtype=self.model.dtype if key == "pixel_values" else torch.int,
                     )
                     for image_enc in image_encs
                 ],
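
Note (not part of the commit): the `dtype` change above guards against casting integer-valued image metadata to the model's float dtype. A standalone sketch of the pitfall (the `image_sizes` key is an assumption based on LLaVA-NeXT-style processors, not taken from this diff):

```python
# pixel_values should match the model dtype (often float16), but integer
# metadata such as image_sizes must stay integral; a blanket float cast would
# corrupt it (large integers lose precision in float16, and models index with it).
import torch

image_enc = {"pixel_values": [[0.5, -0.5]], "image_sizes": [[336, 672]]}
batched = {
    key: torch.tensor(val, dtype=torch.float16 if key == "pixel_values" else torch.int)
    for key, val in image_enc.items()
}
assert batched["image_sizes"].dtype == torch.int32  # torch.int is int32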
@@ -453,15 +463,16 @@ class HFMultimodalLM(HFLM):
             # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
             # speeds up some multiple-choice tasks proportionally to the number of choices.
             # groups requests by context+continuation[:-1] and infer on one request/group.
-            return req[-1] + req[-3] + req[-2][:-1]
+            return req[-3] + req[-2]

         re_ord = Collator(
             requests,
             sort_fn=_collate,
-            group_by="contexts"  # TODO: can't group-by just "contexts" any more, need to incorporate imgs
-            if self.backend == "causal" and self.logits_cache
-            else None,
-            group_fn=_lookup_one_token_cont,
+            group_by=None,
+            # group_by="contexts"  # TODO: can't group-by just "contexts" any more, need to incorporate imgs
+            # if self.backend == "causal" and self.logits_cache
+            # else None,
+            # group_fn=_lookup_one_token_cont,
         )

         # automatic (variable) batch size detection for vectorization
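
Note (not part of the commit): the grouping is disabled because the one-token-continuation lookup keys only on token text, so once images enter the request, two requests with identical text but different images would share a key and wrongly reuse logits. A toy sketch of the collision (the request layout here is illustrative, not the harness's exact tuple):

```python
# A text-only key cannot distinguish requests that differ only in their images.
req_a = {"ctx_toks": (1, 2, 3), "cont_toks": (4,), "image": "cat.png"}
req_b = {"ctx_toks": (1, 2, 3), "cont_toks": (4,), "image": "dog.png"}

key = lambda r: r["ctx_toks"] + r["cont_toks"]  # text-only key, like _collate
assert key(req_a) == key(req_b)  # collision: sharing cached logits here is unsafe
```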
@@ -545,7 +556,12 @@ class HFMultimodalLM(HFLM):
             )  # TODO: fix/test for bs>1 case with differently-sized imgs!

             multi_logits = F.log_softmax(
-                self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs),
+                self._model_multimodal_call(
+                    batched_inps,
+                    batched_imgs,
+                    attn_mask=torch.ones_like(batched_inps),
+                    **call_kwargs,
+                ),
                 dim=-1,
             )  # [batch, padding_length (inp or cont), vocab]
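
Note (not part of the commit): at this loglikelihood call site the mask is simply all ones over the padded batch, which treats pad positions as real tokens; the TODO above flags that the batch-size > 1 case is untested. A toy sketch of what gets passed:

```python
# attn_mask is ones with the same (batch, seq_len) shape as the padded input
# ids, so pad positions are attended to as if they were real tokens.
import torch

batched_inps = torch.tensor([[101, 7, 8, 9], [101, 7, 0, 0]])  # toy padded batch
attn_mask = torch.ones_like(batched_inps)
assert attn_mask.shape == batched_inps.shape  # (2, 4), all ones
```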