Unverified Commit 5ad39dd3 authored by Wang, Yi, committed by GitHub

fix crash in multi-modal (#2245)



* fix crash in multi-modal

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update according to review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix llava_next regression in latest main

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent a8950294
@@ -424,7 +424,7 @@ class FlashLlamaModel(torch.nn.Module):
                 FlashLlamaLayer(
                     index=0,
                     prefix=(
-                        "model.layers.0" if not prefix else "{prefix}.model.layers.0"
+                        "model.layers.0" if not prefix else f"{prefix}.model.layers.0"
                     ),
                     config=config,
                     weights=weights,
...
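The FlashLlamaModel hunk fixes a missing f-string prefix: without the leading f, Python keeps "{prefix}.model.layers.0" as a literal string, so the weight loader looks up a key containing the raw text {prefix} and fails whenever a non-empty prefix is passed, as it is when the Llama trunk is embedded in a multi-modal model. A minimal sketch of the difference (the prefix value language_model is illustrative, not taken from this diff):

prefix = "language_model"

# Missing the f: "{prefix}" is kept verbatim, producing a key that
# does not exist in the checkpoint.
broken = "model.layers.0" if not prefix else "{prefix}.model.layers.0"
print(broken)  # {prefix}.model.layers.0

# With the f-string the prefix is interpolated as intended.
fixed = "model.layers.0" if not prefix else f"{prefix}.model.layers.0"
print(fixed)  # language_model.model.layers.0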
@@ -832,6 +832,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
+            adapter_data=adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
...
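The multi-modal crash itself comes from the wrappers calling the text model's forward without the adapter_data argument it now requires, which raises a TypeError at call time; Idefics2 here and LlavaNext in the next hunk get the same one-line pass-through. A hedged sketch of the failure mode, with illustrative class names rather than the actual TGI code:

class TextTrunk:
    # Stand-in for the LoRA-aware language model: adapter_data is now
    # a required parameter of its forward signature.
    def forward(self, hidden_states, adapter_data):
        return hidden_states

class MultiModalWrapper:
    def __init__(self):
        self.text_model = TextTrunk()

    def forward(self, hidden_states, adapter_data):
        # Before this patch the call omitted adapter_data:
        #   self.text_model.forward(hidden_states)
        #   -> TypeError: forward() missing 1 required positional
        #      argument: 'adapter_data'
        return self.text_model.forward(hidden_states, adapter_data=adapter_data)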
@@ -280,6 +280,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
+            adapter_data=adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
...
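LlavaNextForConditionalGeneration needs the identical pass-through: both multi-modal wrappers delegate text decoding to the same LoRA-aware trunk, so both call sites must forward adapter_data.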
@@ -14,6 +14,7 @@ from text_generation_server.models.flash_causal_lm import (
 )
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
+from text_generation_server.layers.attention import Seqlen

 tracer = trace.get_tracer(__name__)
@@ -348,6 +349,7 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
...
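The last two hunks address the llava_next regression on latest main: the refactored flash-attention path expects the per-request lengths wrapped in a Seqlen object, and VlmCausalLM.forward was still handing self.model.forward the raw tensor. A minimal sketch of the wrapper pattern, assuming a Seqlen dataclass whose only field confirmed by this diff is input_lengths:

import torch
from dataclasses import dataclass

@dataclass
class Seqlen:
    # Stand-in for text_generation_server.layers.attention.Seqlen;
    # only the input_lengths field is confirmed by this diff.
    input_lengths: torch.Tensor

input_lengths = torch.tensor([12, 7, 31], dtype=torch.int32)
# Before the fix the bare tensor was passed along; the attention layers
# expect the wrapper, so it is rebuilt just before the model call.
input_lengths = Seqlen(input_lengths=input_lengths)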