Unverified Commit 5ca085b8 authored by Xuan-Phi Nguyen, committed by GitHub

Better llava next. (#29850)



* Better llava next.
- Batched forward with multiple images of different sizes (different numbers of patches).
- Support training, including cases without any image.
- Support multiple images in the same sequence, e.g.: ["<image> <image> the first image is a dog while the second is a cat", "<image> <image> <image> <image> these 4 images are..."] (see the usage sketch below)

Current limitations:
- Testing has not been done yet.
- Only right padding is supported (for training).
- Left padding (batched generation) is not ready yet.
- PR not ready.
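For orientation, the snippet below sketches the batched, mixed-resolution usage this PR is aiming at. It is an editorial illustration rather than part of the commit: the checkpoint name, image URLs, and prompt template are taken from the slow tests further down, and the exact generation output is not guaranteed.

```python
# Hedged sketch of batched inference with images of different resolutions.
# Assumes the llava-hf/llava-v1.6-mistral-7b-hf checkpoint used in the slow tests below.
import requests
import torch
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

cats = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
deer = Image.open(
    requests.get(
        "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e",
        stream=True,
    ).raw
)

prompts = [
    "[INST] <image>\nWhat is shown in this image? [/INST]",
    "[INST] <image>\nWhat is shown in this image? [/INST]",
]
# padding=True pads the tokenized prompts to a common length; do_pad=True pads each image's
# patch stack with zero patches so differently sized images can share one pixel_values tensor.
inputs = processor(prompts, images=[cats, deer], padding=True, do_pad=True, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(output, skip_special_tokens=True))
```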

* fix bugs in batched generation

* add tests

* fix batch-gen bugs, left-padding positions and incorrect attention mask

* remove better modeling llava

* fix formatting

* fix test

* fix test

* fix testing

* fix test

* fix formatting

* Update src/transformers/models/llava_next/modeling_llava_next.py

add clarity
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update modeling_llava_next.py

remove assert

* fix bug modeling_llava_next.py

* update modeling

* fix bugs

* fix format

* fix error

* fix new_token_positions

* Update modeling_llava_next.py

* update formatting

* add args

* remove comments

* add slow tests for batched inference

* failing tf/flax tests

* this one is correct

* Update src/transformers/models/llava_next/modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix docs

* make fixup

* more fixup

* add test for batch equivalence

* Update tests/models/llava_next/test_modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/image_processing_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/image_processing_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/modeling_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pr comments

* hardcode padding side for bs=1

* update

* [run-slow] llava_next

* [run-slow] llava_next

* make fix-copies

---------
Co-authored-by: NGUYEN, Xuan Phi <x.nguyen@alibaba-inc.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
parent bdfefbad
@@ -15,12 +15,13 @@
"""Image processor class for LLaVa-NeXT."""

import math
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from ...image_transforms import (
    PaddingMode,
    convert_to_rgb,
    get_resize_output_image_size,
    pad,
@@ -154,6 +155,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image. If `True`, will pad the images in the batch to the largest image in the batch
            and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """
@@ -173,6 +177,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = True,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
@@ -251,6 +256,74 @@ class LlavaNextImageProcessor(BaseImageProcessor):
            **kwargs,
        )

    def pad(
        self,
        image: np.ndarray,
        padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
        mode: PaddingMode = PaddingMode.CONSTANT,
        constant_values: Union[float, Iterable[float]] = 0.0,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
        dimensions or in the (`num_patches`) dimension. In the second case an iterable of tuples is expected
        as input.

        Args:
            image (`np.ndarray`):
                The image to pad.
            padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
                Padding to apply to the edges of the height and width axes. Can be one of three formats:
                - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
                - `((before, after),)` yields the same before and after pad for height and width.
                - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
            mode (`PaddingMode`):
                The padding mode to use. Can be one of:
                    - `"constant"`: pads with a constant value.
                    - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                      vector along each axis.
                    - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                    - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
            constant_values (`float` or `Iterable[float]`, *optional*):
                The value to use for the padding if `mode` is `"constant"`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use the same format as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use the inferred format of the input image.

        Returns:
            `np.ndarray`: The padded image.
        """
        # call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
        if isinstance(padding, int) or len(padding) != 4:
            return pad(image, padding, mode, constant_values, data_format, input_data_format)

        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)
        if mode == PaddingMode.CONSTANT:
            image = np.pad(image, padding, mode="constant", constant_values=constant_values)
        elif mode == PaddingMode.REFLECT:
            image = np.pad(image, padding, mode="reflect")
        elif mode == PaddingMode.REPLICATE:
            image = np.pad(image, padding, mode="edge")
        elif mode == PaddingMode.SYMMETRIC:
            image = np.pad(image, padding, mode="symmetric")
        else:
            raise ValueError(f"Invalid padding mode: {mode}")

        image = (
            to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
        )
        return image
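For readers skimming the diff: the patch-dimension branch above (a `padding` tuple of length 4) reduces to a plain `np.pad` over the first axis. A standalone numpy illustration with made-up shapes:

```python
# Hypothetical shapes: a stack of 3 patches, each 3x336x336, padded on the right of the
# num_patches axis up to 5 patches, mirroring the constant-mode branch of `pad` above.
import numpy as np

patches = np.random.rand(3, 3, 336, 336)        # (num_patches, num_channels, height, width)
padding = ((0, 2), (0, 0), (0, 0), (0, 0))      # pad only the num_patches axis
padded = np.pad(patches, padding, mode="constant", constant_values=0.0)
print(padded.shape)  # (5, 3, 336, 336)
```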
    def _preprocess(
        self,
        images: ImageInput,
@@ -378,7 +451,7 @@
        paste_x = (target_width - new_width) // 2
        paste_y = (target_height - new_height) // 2

-        padded_image = pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
+        padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))

        return padded_image
@@ -446,6 +519,45 @@
        return image_patches

    def _pad_for_batching(
        self,
        pixel_values: List[np.ndarray],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Pads images on the `num_of_patches` dimension with zeros to form a batch with the same number of patches.

        Args:
            pixel_values (`List[np.ndarray]`):
                An array of pixel values for each image, of shape (`batch_size`, `num_patches`, `image_in_3D`)
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use the same format as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use the inferred format of the input image.

        Returns:
            List[`np.ndarray`]: The padded images.
        """
        max_patch = max(len(x) for x in pixel_values)
        pixel_values = [
            self.pad(
                image,
                padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for image in pixel_values
        ]

        return pixel_values

    def preprocess(
        self,
        images: ImageInput,
@@ -460,6 +572,7 @@
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = True,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
@@ -496,6 +609,9 @@
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image. If `True`, will pad the images in the batch to the largest image in the batch
                and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
@@ -516,6 +632,7 @@
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
@@ -603,6 +720,9 @@
            pixel_values = np.array(pixel_values)
            new_images.append(pixel_values)

-        data = {"pixel_values": new_images, "image_sizes": image_sizes}
-        return BatchFeature(data=data, tensor_type=return_tensors)
+        if do_pad:
+            processed_images = self._pad_for_batching(new_images)
+        else:
+            # fall back to the un-padded list so `processed_images` is always defined
+            processed_images = new_images
+
+        return BatchFeature(
+            data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
+        )
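To make the effect of `do_pad` in `preprocess` concrete, here is a hedged sketch using the published llava-hf/llava-v1.6-mistral-7b-hf image processor; the exact `max_num_patches` depends on the image sizes and `image_grid_pinpoints`, so the shapes in the comments are indicative only.

```python
# Hedged sketch: with do_pad=True the per-image patch stacks are zero-padded to the
# batch-wide maximum, so pixel_values comes back as one dense tensor.
import requests
from PIL import Image
from transformers import LlavaNextImageProcessor

image_processor = LlavaNextImageProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

cats = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
small = cats.resize((224, 224))  # a second, lower-resolution image -> possibly fewer patches

batch = image_processor(images=[cats, small], do_pad=True, return_tensors="pt")
print(batch["pixel_values"].shape)  # roughly (2, max_num_patches, 3, 336, 336) for this checkpoint
print(batch["image_sizes"])         # the original (height, width) of each input image
```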
@@ -16,7 +16,6 @@
Processor class for LLaVa-NeXT.
"""

from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
@@ -53,7 +52,8 @@ class LlavaNextProcessor(ProcessorMixin):
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length=None,
+        max_length: Optional[int] = None,
+        do_pad: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
@@ -82,6 +82,9 @@
                lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image. If `True`, will pad the images in the batch to the largest image in the batch
                and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
@@ -102,7 +105,7 @@
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        if images is not None:
-            image_inputs = self.image_processor(images, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
        else:
            image_inputs = {}
        text_inputs = self.tokenizer(
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Testing suite for the PyTorch Llava-NeXT model. """
+"""Testing suite for the PyTorch Llava-NeXT model."""

import gc
import unittest
@@ -46,6 +46,8 @@ from ...test_modeling_common import (
if is_torch_available():
    import torch

    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

else:
    is_torch_greater_or_equal_than_2_0 = False
@@ -121,7 +123,7 @@ class LlavaNextVisionText2TextModelTester:
        self.batch_size = 3
        self.num_channels = 3
        self.image_size = 30
-        self.encoder_seq_length = 342
+        self.encoder_seq_length = 341
        self.image_grid_pinpoints = [[32, 32]]

    def get_config(self):
@@ -153,10 +155,15 @@
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
-        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
-        attention_mask = input_ids.ne(1).to(torch_device)
+        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
+        # make the attention mask left-padded to avoid issues with "model has no attribute padding_side"
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        attention_mask[:, :1] = 0
        # we are giving 3 images, let's make sure we pass in 3 image tokens
        input_ids[:, 1] = config.image_token_index
+        labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
+        # mask out the position where the image token is
+        labels[:, 1] = self.ignore_index
        inputs_dict = {
            "pixel_values": pixel_values,
            "image_sizes": torch.tensor(
@@ -164,6 +171,7 @@
            ),
            "input_ids": input_ids,
            "attention_mask": attention_mask,
+            "labels": labels,
        }
        return config, inputs_dict
@@ -341,10 +349,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
            padding=True,
        ).to(torch_device)

-        # make sure image_sizes are the same
-        # as otherwise batched generation doesn't work
-        inputs.image_sizes[1] = inputs.image_sizes[0]
+        # it should not matter whether two images are the same size or not
        output = model.generate(**inputs, max_new_tokens=20)

        EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush']  # fmt: skip
@@ -378,3 +383,85 @@
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )
    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_batch_different_resolutions(self):
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            load_in_4bit=True,
        )

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
        cats_image = Image.open(requests.get(url, stream=True).raw)
        lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

        inputs = self.processor(
            [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
        ).to(torch_device)
        pixel_values = inputs["pixel_values"]

        # verify pixel values are padded correctly with 0 when one image has more num_patches than the other
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=model.config.image_grid_pinpoints,
                patch_size=model.config.vision_config.image_size,
            )
            for imsize in inputs["image_sizes"]
        ]
        for pix_val, num_patch in zip(pixel_values, image_num_patches):
            self.assertTrue(torch.all(pix_val[num_patch:] == 0))  # pad on the right
            for i in range(num_patch):
                self.assertFalse(torch.all(pix_val[i : i + 1] == 0))  # no padding expected in any of the patches

        # check loss when labels are passed
        inputs["labels"] = inputs["input_ids"].clone()
        with torch.no_grad():
            output = model(**inputs)

        expected_slice = torch.tensor(
            [[-0.0308, -0.0313, -0.0314], [-0.3064, -0.3013, -0.2986], [-0.1226, -0.1246, -0.1210]],
            dtype=torch.float32,
            device=torch_device,
        )
        assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3)
        assert torch.allclose(output.loss, torch.tensor(6.8619, device=torch_device))

        # verify generation
        output = model.generate(**inputs, max_new_tokens=50)
        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few deer grazing. The deer are partially obscured by the fog, and the trees in the background'  # fmt: skip
        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_batch_matches_single(self):
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            load_in_4bit=True,
        )

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
        cats_image = Image.open(requests.get(url, stream=True).raw)
        lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

        inputs_batched = self.processor(
            [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
        ).to(torch_device)

        inputs_single = self.processor(self.prompt, images=lowres_img, return_tensors="pt", padding=True).to(
            torch_device
        )

        # verify generation
        output_batched = model.generate(**inputs_batched, max_new_tokens=50)
        output_single = model.generate(**inputs_single, max_new_tokens=50)
        self.assertEqual(
            self.processor.decode(output_batched[0], skip_special_tokens=True),
            self.processor.decode(output_single[0], skip_special_tokens=True),
        )
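As a footnote to the patch-count check in `test_small_model_integration_test_batch_different_resolutions`, the snippet below shows how the imported `image_size_to_num_patches` helper can be probed on its own. The grid pinpoints listed here are an assumption based on the common LLaVA-NeXT configuration, not values read from the checkpoint.

```python
# Hedged illustration of the helper the test imports: it returns the number of image
# patches (including the base patch) that an image of a given size will produce.
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

# Assumed grid pinpoints; the real values come from model.config.image_grid_pinpoints.
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

for size in [(480, 640), (336, 336)]:
    n = image_size_to_num_patches(image_size=size, grid_pinpoints=grid_pinpoints, patch_size=336)
    print(size, "->", n, "patches")
```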