Unverified Commit afe1e465 authored by Lianmin Zheng, committed by GitHub

[Minor] fix the style for multimodal models (#2257)

parent f50a6cf4
@@ -149,8 +149,8 @@ class ImageInputs:
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
-        # cuda memory access.
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
         ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

         optional_args = [
@@ -172,8 +172,8 @@ class ImageInputs:
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
-        # cuda memory access.
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
         self.image_hashes += other.image_hashes
         self.pad_values = [x % (1 << 30) for x in self.image_hashes]
...
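The comment above is the reason for the `% (1 << 30)` fold: image hashes stand in for token ids so the radix cache can prefix-match multimodal requests, but those fake ids can exceed the vocabulary, so any model that later embeds `input_ids` has to clamp them first. A minimal sketch of that clamping, assuming a torch tensor of input ids and the model's `vocab_size` (the helper name is hypothetical, not code from this repository):

import torch

def clamp_fake_image_tokens(input_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # The hash-derived pad values can lie anywhere in [0, 2**30); clamp them
    # into [0, vocab_size) so the embedding lookup never indexes out of bounds.
    return input_ids.clamp(min=0, max=vocab_size - 1)

The rewritten comment points readers at llava.py for the in-tree version of this guard.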
@@ -568,15 +568,17 @@ class Scheduler:
             )
             req.extend_image_inputs(image_inputs)

-            if len(req.origin_input_ids) > self.max_req_input_len:
-                req.finished_reason = FINISH_ABORT(
-                    "Image request length is longer than the KV cache pool size or "
-                    "the max context length. "
-                    "Abort this request because you cannot truncate the image embeds"
-                )
-                req.image_inputs = None
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                logger.error(
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
+                )
                 req.origin_input_ids = [0]
+                req.image_inputs = None
                 req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    "Multimodal prompt is too long. Check server logs for details."
+                )
                 self.waiting_queue.append(req)
                 return
...
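The functional change here is small: the threshold check now uses `>=`, the full details (expanded length versus the limit) go to the server log, and the client gets a short `FINISH_ABORT` message pointing at those logs; the request is still neutralized (`origin_input_ids = [0]`, image inputs dropped, `max_new_tokens = 0`) and pushed through the waiting queue so the abort is reported along the normal path. The new log line relies on Python 3.8+ self-documenting f-strings, where `f"{expr=}"` prints both the expression text and its value. A quick illustration with made-up values, not taken from the scheduler:

ids = [0] * 5
limit = 4
print(f"{len(ids)=} >= {limit}")
# -> len(ids)=5 >= 4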
@@ -134,7 +134,6 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs

         if forward_batch.forward_mode.is_extend():
-            bs = forward_batch.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -142,7 +141,7 @@ class LlavaBaseForCausalLM(nn.Module):
             for im in image_inputs:
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
-                if im and im.image_offsets is not None:
+                if im and im.image_offsets:
                     max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)
@@ -159,6 +158,7 @@ class LlavaBaseForCausalLM(nn.Module):
             need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
+                bs = forward_batch.batch_size
                 pixel_values = [
                     image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
                 ]
...
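Two small behavior notes in this hunk: moving `bs = forward_batch.batch_size` inside `if need_vision.any()` only narrows its scope to where it is used, and switching `im.image_offsets is not None` to a plain truthiness check also guards against an empty offsets list, where `max(im.image_offsets)` would raise `ValueError`. The `need_vision` mask compares each request's extend start position with its last image-token offset; a small numpy sketch with made-up values (not data from the model):

import numpy as np

# Hypothetical per-request values: where the extend segment starts, and the
# largest image-token offset in that request (-1 means no image at all).
start_positions = np.array([0, 16, 64])
max_image_offset = np.array([10, 10, -1])

# A request only needs the vision tower if its extend segment can still
# overlap an image placeholder, i.e. it starts at or before the last image token.
need_vision = start_positions <= max_image_offset
print(need_vision)  # [ True False False]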