Unverified Commit 92b0ce2a authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Bugfix][v1] fixed llava-hf/llava-1.5-7b-hf is broken on V1 (#14554)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bc2d4473
......@@ -783,15 +783,19 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
if kwargs.get("v0_path", False) or \
image_input.get("feat_is_patch") is None or \
image_input.get("embed_is_patch") is None:
# The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings
else:
nested_emb = [
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)
nested_emb = [
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)
def get_input_embeddings(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment