Commit 3c772593 authored by Baber

add attn_mask (llava models need it)

parent 4142b7b2
 import copy
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -21,6 +21,9 @@ from lm_eval.models.utils import (
 from lm_eval.utils import add_padding_if_needed

+if TYPE_CHECKING:
+    import PIL
+

 DEFAULT_IMAGE_PLACEHOLDER = "<image>"
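
Note on the TYPE_CHECKING guard above: it keeps PIL out of the runtime import graph, and the quoted annotation added later in this commit is only resolved by static type checkers. A minimal sketch of the pattern (the count_images helper is hypothetical, not from the diff):

    from typing import TYPE_CHECKING, List

    if TYPE_CHECKING:
        import PIL  # evaluated only by type checkers, never imported at runtime

    def count_images(images: List["PIL.Image.Image"]) -> int:
        # hypothetical helper: the string annotation defers resolution,
        # so this runs even when Pillow is not installed
        return len(images)

    print(count_images([]))  # 0
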
@@ -175,7 +178,9 @@ class HFMultimodalLM(HFLM):
         return text_encoding, encoding  # image_encoding is a dict

-    def _encode_multimodal_pair(self, context, continuation, images):
+    def _encode_multimodal_pair(
+        self, context, continuation, images: List["PIL.Image.Image"]
+    ):
         """Helper function to perform the role of TemplateLM._encode_pair
         Except allowing for image input to also be processed alongside `context`.
@@ -192,6 +197,9 @@ class HFMultimodalLM(HFLM):
             context, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
         )

+        if self.rgb:
+            images = [img.convert("RGB") for img in images]
+
         whole_enc, image_enc = self.tok_multimodal_encode(
             context + continuation, images
         )
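
The new rgb branch normalizes single-channel ("L"), palette ("P"), or RGBA inputs down to three channels, since some multimodal processors cannot stack mixed-mode images into one pixel tensor. A standalone sketch, assuming Pillow is available:

    from PIL import Image

    # a 2x2 RGBA image; convert("RGB") drops the alpha channel
    rgba = Image.new("RGBA", (2, 2), (255, 0, 0, 128))
    rgb = rgba.convert("RGB")
    print(rgba.mode, "->", rgb.mode)  # RGBA -> RGB
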
@@ -346,7 +354,7 @@ class HFMultimodalLM(HFLM):
         """
         # note: imgs is a dict.
         with torch.no_grad():
-            return self.model(inps, attention_mask=torch.ones_like(inps), **imgs).logits
+            return self.model(inps, **imgs, attention_mask=attn_mask).logits

     def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
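
On the _model_multimodal_call change: the attention mask is no longer hardcoded to all ones inside the method but supplied by the caller, which is what llava-family models need once padded batches are involved. The loglikelihood call site in the last hunk below still passes torch.ones_like, so this commit mainly threads the mask through. A sketch of how a real padding mask differs from all ones (pad id 0 is an assumption):

    import torch

    inps = torch.tensor([[101, 7592, 2054, 0, 0]])  # right-padded batch, pad id 0 assumed
    ones_mask = torch.ones_like(inps)      # what the old hardcoded version used
    real_mask = (inps != 0).long()         # masks out the padding positions
    print(ones_mask.tolist())  # [[1, 1, 1, 1, 1]]
    print(real_mask.tolist())  # [[1, 1, 1, 0, 0]]
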
@@ -384,7 +392,9 @@ class HFMultimodalLM(HFLM):
                 batched_imgs[key] = torch.cat(
                     [
                         torch.tensor(
-                            image_enc[key], device=self.device, dtype=self.model.dtype
+                            image_enc[key],
+                            device=self.device,
+                            dtype=self.model.dtype if key == "pixel_values" else torch.int,
                         )
                         for image_enc in image_encs
                     ],
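
The dtype branch exists because processor outputs mix float tensors (pixel_values) with integer metadata (keys such as image_sizes in llava-style outputs; that key name is an assumption here), and casting everything to a half-precision model dtype would corrupt the integer entries. A small sketch:

    import torch

    model_dtype = torch.float16  # assumption: model loaded in half precision
    image_enc = {
        "pixel_values": [[0.1, 0.2]],  # float image features follow the model dtype
        "image_sizes": [[336, 336]],   # integer metadata must stay integral
    }
    batched = {
        key: torch.tensor(val, dtype=model_dtype if key == "pixel_values" else torch.int)
        for key, val in image_enc.items()
    }
    print({k: v.dtype for k, v in batched.items()})
    # {'pixel_values': torch.float16, 'image_sizes': torch.int32}
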
@@ -453,15 +463,16 @@ class HFMultimodalLM(HFLM):
             # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
             # speeds up some multiple-choice tasks proportionally to the number of choices.
             # groups requests by context+continuation[:-1] and infer on one request/group.
-            return req[-1] + req[-3] + req[-2][:-1]
+            return req[-3] + req[-2]

         re_ord = Collator(
             requests,
             sort_fn=_collate,
-            group_by="contexts"  # TODO: can't group-by just "contexts" any more, need to incorporate imgs
-            if self.backend == "causal" and self.logits_cache
-            else None,
-            group_fn=_lookup_one_token_cont,
+            group_by=None,
+            # group_by="contexts"  # TODO: can't group-by just "contexts" any more, need to incorporate imgs
+            # if self.backend == "causal" and self.logits_cache
+            # else None,
+            # group_fn=_lookup_one_token_cont,
         )
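
Context for disabling group_by: the lookup reuses one forward pass across requests whose context and continuation prefix match, but as the TODO notes, a key built from text tokens alone collides when two requests differ only in their images. A toy illustration with hypothetical request tuples:

    # hypothetical (context_toks, continuation_toks, image) request tuples
    req_a = ((5, 11, 42), (7,), "image_1.png")
    req_b = ((5, 11, 42), (7,), "image_2.png")

    def text_key(req):
        return req[0] + req[1]  # key built from text tokens only; images ignored

    print(text_key(req_a) == text_key(req_b))  # True: same key, different images
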
# automatic (variable) batch size detection for vectorization
@@ -545,7 +556,12 @@ class HFMultimodalLM(HFLM):
             )  # TODO: fix/test for bs>1 case with differently-sized imgs!

             multi_logits = F.log_softmax(
-                self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs),
+                self._model_multimodal_call(
+                    batched_inps,
+                    batched_imgs,
+                    attn_mask=torch.ones_like(batched_inps),
+                    **call_kwargs,
+                ),
                 dim=-1,
             )  # [batch, padding_length (inp or cont), vocab]
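
After this call, F.log_softmax has normalized the logits over the vocabulary axis, and the unchanged code that follows gathers the entries at the continuation token ids. A toy sketch of that gather step, with made-up shapes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 4, 10)             # [batch, padding_length, vocab]
    logprobs = F.log_softmax(logits, dim=-1)   # normalize over the vocab axis
    cont_toks = torch.tensor([[3, 7, 1, 4]])   # continuation token ids, [batch, seq]
    per_token = torch.gather(logprobs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
    print(per_token.shape)  # torch.Size([1, 4])
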