Unverified commit ff647398, authored by Yan Ma, committed by GitHub
Browse files

[Bugfix][Model] fix mllama multi-image (#14883)


Signed-off-by: yan ma <yan.ma@intel.com>
parent a164aea3
......@@ -212,7 +212,7 @@ def _run_test(
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
max_num_seqs=2,
max_num_seqs=3,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
......
......@@ -1235,11 +1235,34 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def unpack_data(self,
                image_data: Union[List[torch.Tensor], torch.Tensor],
                padding_value=0) -> torch.Tensor:
    """Return ``image_data`` as one tensor, padding a list when necessary.

    A tensor input is already batched and is returned unchanged (same
    object). A list of tensors is stacked along a new leading dimension;
    each entry is right-padded along its own first dimension with
    ``padding_value`` up to the longest entry in the list.

    Args:
        image_data: Either an already-batched tensor, or a non-empty list
            of tensors whose shapes agree on every dimension except the
            first.
        padding_value: Fill value used for the padded positions.

    Returns:
        The input tensor itself, or a new tensor of shape
        ``(len(image_data), max_first_dim, *trailing_dims)`` that takes its
        dtype and device from the first list entry.

    Raises:
        ValueError: If the list is empty, contains a non-tensor entry
            (e.g. a nested list), contains a 0-dim tensor, or its entries
            disagree on the trailing dimensions.
    """
    if isinstance(image_data, torch.Tensor):
        return image_data

    # Explicit checks instead of ``assert``: asserts vanish under
    # ``python -O``, and without them a trailing-dim mismatch could be
    # silently broadcast by the slice assignment below.
    if len(image_data) == 0:
        raise ValueError("Image data is empty; cannot build a batch.")
    if not all(isinstance(t, torch.Tensor) for t in image_data):
        raise ValueError("Image data is not properly batched.")

    first = image_data[0]
    trailing_dims = first.shape[1:]
    for t in image_data:
        if t.dim() == 0:
            raise ValueError(
                "Image data entries must have at least one dimension.")
        if t.shape[1:] != trailing_dims:
            raise ValueError(
                "Image data entries must share trailing dimensions: "
                f"expected {tuple(trailing_dims)}, "
                f"got {tuple(t.shape[1:])}.")

    max_length = max(t.size(0) for t in image_data)
    output_tensor = torch.full(
        (len(image_data), max_length, *trailing_dims),
        padding_value,
        dtype=first.dtype,
        device=first.device)
    # Copy each entry into the front of its row; the tail keeps the padding.
    for i, t in enumerate(image_data):
        output_tensor[i, :t.size(0)] = t
    return output_tensor
def _parse_and_validate_image_input(self, **kwargs: object):
# tensor with the same shape will be batched together by
# MultiModalKwargs.batch, so pixel_values here can be:
# - List[List[torch.Tensor]]:
# with shape (num_tiles, 3, image_res, image_res)
# - List[torch.Tensor]:
# with shape (num_image, num_tiles, 3, image_res, image_res)
# - torch.Tensor:
......@@ -1274,10 +1297,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
return MllamaImagePixelInputs(
type="pixel_values",
data=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
)
data=self.unpack_data(pixel_values),
aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
if image_embeds is not None:
raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment