Unverified commit f50a6cf4, authored by Lianmin Zheng and committed by GitHub

Fix hash collision for multi modal models (#2256)

parent fe97a2d4
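For context on the commit title (this note is not part of the diff; the vocab size and hash values below are made up for illustration), the collision comes from mapping image hashes into the token-id range: the old code derived the pad values with `% vocab_size`, so two different images could land on identical fake token ids and incorrectly share a radix-cache prefix. The fix keeps 30 bits of each hash as the pad value and only clamps the ids inside the model forward.

```python
# Hypothetical numbers illustrating the collision; only the two moduli come from the diff.
vocab_size = 32000                   # assumed text vocab size
h1 = 123456789                       # pretend hash of image A
h2 = h1 + vocab_size                 # pretend hash of image B, colliding modulo vocab_size

assert h1 % vocab_size == h2 % vocab_size    # old scheme: identical fake token ids
assert h1 % (1 << 30) != h2 % (1 << 30)      # new scheme: distinct 30-bit pad values
```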
@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
 class ImageInputs:
     """The image related inputs."""
 
-    pixel_values: torch.Tensor
+    pixel_values: Union[torch.Tensor, np.array]
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
@@ -132,7 +132,7 @@ class ImageInputs:
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    image_embeds: Optional[List[torch.Tensor]] = None
+    # Llava related
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
@@ -141,21 +141,17 @@ class ImageInputs:
     mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
-    def from_dict(obj, vocab_size):
-        # Use image hash as fake token_ids, which is then used for prefix matching
+    def from_dict(obj: dict):
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
             image_hashes=obj["image_hashes"],
         )
-        if not isinstance(ret.image_hashes, list):
-            ret.pad_values = [
-                (ret.image_hashes) % vocab_size,
-                (ret.image_hashes >> 16) % vocab_size,
-                (ret.image_hashes >> 32) % vocab_size,
-                (ret.image_hashes >> 64) % vocab_size,
-            ]
-        else:
-            ret.pad_values = [x % vocab_size for x in ret.image_hashes]
+
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
+        # cuda memory access.
+        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
 
         optional_args = [
             "image_sizes",
@@ -170,21 +166,16 @@ class ImageInputs:
 
         return ret
 
-    def merge(self, other, vocab_size):
+    def merge(self, other):
         assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
         self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-        if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
-            self.image_hashes += other.image_hashes
-            self.pad_values = [x % vocab_size for x in self.image_hashes]
-        else:
-            self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes))
-            self.pad_values = [
-                (self.image_hashes) % vocab_size,
-                (self.image_hashes >> 16) % vocab_size,
-                (self.image_hashes >> 32) % vocab_size,
-                (self.image_hashes >> 64) % vocab_size,
-            ]
+
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
+        # cuda memory access.
+        self.image_hashes += other.image_hashes
+        self.pad_values = [x % (1 << 30) for x in self.image_hashes]
 
         optional_args = [
             "image_sizes",
@@ -297,11 +288,11 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-    def extend_image_inputs(self, image_inputs, vocab_size):
+    def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
         else:
-            self.image_inputs.merge(image_inputs, vocab_size)
+            self.image_inputs.merge(image_inputs)
 
     # whether request reached finished condition
     def finished(self) -> bool:
...
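To make the role of the new `pad_values` concrete, here is a standalone sketch (the token ids, the image hash, and the prefix-length loop are illustrative assumptions, not the repository's RadixCache API): the 30-bit values stand in for the image inside the token-id sequence that prefix matching keys on, so two requests with the same image can share a cached prefix.

```python
# Mirrors the new ImageInputs.from_dict above: one 30-bit pad value per image hash.
def fake_tokens_for_image(image_hashes):
    return [h % (1 << 30) for h in image_hashes]

image = [0xDEADBEEFCAFEBABE]                       # pretend hash of one image
prompt_a = [101, 7592] + fake_tokens_for_image(image) + [2023]
prompt_b = [101, 7592] + fake_tokens_for_image(image) + [4060]

# Both requests share the text tokens plus identical image pad values, so that prefix
# can be reused from the KV cache; a different image would change the pad values.
shared = 0
for x, y in zip(prompt_a, prompt_b):
    if x != y:
        break
    shared += 1
print(shared)  # 3 shared leading tokens: [101, 7592, <image pad value>]
```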
@@ -526,8 +526,9 @@ class Scheduler:
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        # Create a new request
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Create a new request
             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
                 seq_length = len(recv_req.input_embeds)
@@ -558,20 +559,20 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        # Image inputs
+        # Handle image inputs
         if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(
-                recv_req.image_inputs, self.model_config.vocab_size
-            )
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
             )
-            req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
+            req.extend_image_inputs(image_inputs)
 
             if len(req.origin_input_ids) > self.max_req_input_len:
                 req.finished_reason = FINISH_ABORT(
                     "Image request length is longer than the KV cache pool size or "
-                    "the max context length aborting because you cannot truncate the image embeds"
+                    "the max context length. "
+                    "Abort this request because you cannot truncate the image embeds"
                 )
                 req.image_inputs = None
                 req.origin_input_ids = [0]
@@ -579,6 +580,7 @@ class Scheduler:
                 self.waiting_queue.append(req)
                 return
 
+        # Copy more attributes
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
...
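The comment added above notes that a single image token is expanded into multiple dummy tokens. The real `pad_input_ids_func` is registered per model and differs between them; the sketch below is a simplified stand-in (the placeholder id and per-image token count are assumptions) showing only the expansion the scheduler relies on.

```python
# Simplified stand-in for a model-specific pad_input_ids_func; not the repo's implementation.
IMAGE_PLACEHOLDER_ID = -1      # hypothetical marker meaning "an image goes here"

def pad_input_ids(input_ids, pad_value, num_image_tokens):
    out = []
    for tok in input_ids:
        if tok == IMAGE_PLACEHOLDER_ID:
            # These dummy ids are overwritten by vision embeddings in the model forward;
            # they only matter for radix-cache prefix matching.
            out.extend([pad_value] * num_image_tokens)
        else:
            out.append(tok)
    return out

print(pad_input_ids([101, IMAGE_PLACEHOLDER_ID, 2023], pad_value=123456789, num_image_tokens=4))
# [101, 123456789, 123456789, 123456789, 123456789, 2023]
```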
@@ -10,10 +10,7 @@
 # limitations under the License.
 # ==============================================================================
 
-import copy
 import uuid
-from dataclasses import dataclass
-from typing import Optional
 
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
...
@@ -216,6 +216,7 @@ class TokenizerManager:
             input_ids = obj.input_ids
 
         if self.is_generation:
+            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
...
@@ -147,6 +147,11 @@ class LlavaBaseForCausalLM(nn.Module):
             else:
                 max_image_offset.append(-1)
 
+        # Clamp input ids. This is because the input_ids for the image tokens are
+        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # These values are useless because their embeddings will be replaced by vision embeddings anyway.
+        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
         # Embed text inputs
         input_embeds = self.language_model.model.embed_tokens(input_ids)
...
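The clamp added here is exactly what the comments in `ImageInputs.from_dict` warn about: the hash-derived pad values can be far larger than the vocabulary, so they must be clamped before the embedding lookup. A minimal, self-contained sketch (vocab and hidden sizes are made up) of why the lookup is unsafe without it:

```python
import torch
import torch.nn as nn

vocab_size, hidden = 32000, 16
embed_tokens = nn.Embedding(vocab_size, hidden)

# Image positions in input_ids hold 30-bit hash-derived pad values, far above vocab_size.
input_ids = torch.tensor([101, 123456789, 123456789, 2023])

# Without clamping, the lookup indexes out of bounds (an error on CPU, potentially an
# illegal memory access on CUDA). After clamping, whatever embeddings the image
# positions produce are discarded and replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=vocab_size - 1)
input_embeds = embed_tokens(input_ids)
print(input_embeds.shape)  # torch.Size([4, 16])
```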
@@ -597,13 +597,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
                 `None` if no images are passed.
         """
+        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            positions = forward_batch.mrope_positions
+
         image_inputs = None
         if forward_batch.image_inputs is not None:
             image_inputs = [
                 img for img in forward_batch.image_inputs if img is not None
             ]
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
-            positions = forward_batch.mrope_positions
+
         if (
             forward_batch.forward_mode.is_decode()
             or image_inputs is None
@@ -617,6 +619,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 f"(3, seq_len) positions, but got {positions.size()}"
             )
 
+        # Clamp input ids. This is because the input_ids for the image tokens are
+        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # These values are useless because their embeddings will be replaced by vision embeddings anyway.
+        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
         inputs_embeds = self.model.embed_tokens(input_ids)
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
...