Unverified Commit f50a6cf4 authored by Lianmin Zheng, committed by GitHub

Fix hash collision for multi modal models (#2256)

parent fe97a2d4
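For context on the fix: the previous code reduced each 64-bit image hash modulo `vocab_size` to build the fake token ids used as keys for radix-cache prefix matching, so two distinct images could easily collide and match each other's cached prefixes. A minimal, self-contained sketch of the effect; the sample count and moduli are illustrative, not taken from the commit:

import random

# Reduce random 64-bit "image hashes" modulo a small vocab size
# versus modulo 2**30, and count collisions.
random.seed(0)
hashes = [random.getrandbits(64) for _ in range(100_000)]

vocab_size = 32_000      # typical LLM vocabulary size (illustrative)
wide_modulus = 1 << 30   # the modulus this commit switches to

print("collisions with % vocab_size:", len(hashes) - len({h % vocab_size for h in hashes}))
print("collisions with % (1 << 30):", len(hashes) - len({h % wide_modulus for h in hashes}))

With 100,000 hashes and only 32,000 buckets, collisions are guaranteed by the pigeonhole principle; with 2**30 buckets they become rare. This is why the pad values below switch to `% (1 << 30)` and the model forward passes clamp the ids before embedding.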
@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
class ImageInputs:
"""The image related inputs."""
pixel_values: torch.Tensor
pixel_values: Union[torch.Tensor, np.ndarray]
image_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
@@ -132,7 +132,7 @@ class ImageInputs:
modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
image_embeds: Optional[List[torch.Tensor]] = None
# Llava related
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
@@ -141,21 +141,17 @@ class ImageInputs:
mrope_position_delta: Optional[torch.Tensor] = None
@staticmethod
def from_dict(obj, vocab_size):
# Use image hash as fake token_ids, which is then used for prefix matching
def from_dict(obj: dict):
ret = ImageInputs(
pixel_values=obj["pixel_values"],
image_hashes=obj["image_hashes"],
)
if not isinstance(ret.image_hashes, list):
ret.pad_values = [
(ret.image_hashes) % vocab_size,
(ret.image_hashes >> 16) % vocab_size,
(ret.image_hashes >> 32) % vocab_size,
(ret.image_hashes >> 64) % vocab_size,
]
else:
ret.pad_values = [x % vocab_size for x in ret.image_hashes]
# Use the image hash as fake token_ids, which we then use as the key for prefix matching in the radix cache.
# Note that if the `input_ids` are later used in the model forward,
# you also need to clamp the values to the range [0, vocab_size) to avoid illegal
# CUDA memory access.
ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
optional_args = [
"image_sizes",
@@ -170,21 +166,16 @@ class ImageInputs:
return ret
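A hedged sketch of the contract this creates for downstream code; `make_pad_values` is a hypothetical stand-in for the list comprehension above, and `vocab_size` is illustrative:

import torch

def make_pad_values(image_hashes):
    # Hypothetical helper mirroring from_dict: one pad value per image,
    # kept in [0, 2**30) so radix-cache keys rarely collide.
    return [h % (1 << 30) for h in image_hashes]

# Pad values may exceed vocab_size, so any later embedding lookup must
# clamp the ids first, exactly as the comment above warns.
vocab_size = 32_000
ids = torch.tensor(make_pad_values([hash("image-a"), hash("image-b")]))
safe_ids = ids.clamp(min=0, max=vocab_size - 1)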
def merge(self, other, vocab_size):
def merge(self, other):
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
self.image_hashes += other.image_hashes
self.pad_values = [x % vocab_size for x in self.image_hashes]
else:
self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes))
self.pad_values = [
(self.image_hashes) % vocab_size,
(self.image_hashes >> 16) % vocab_size,
(self.image_hashes >> 32) % vocab_size,
(self.image_hashes >> 64) % vocab_size,
]
# Use the image hash as fake token_ids, which we then use as the key for prefix matching in the radix cache.
# Note that if the `input_ids` are later used in the model forward,
# you also need to clamp the values to the range [0, vocab_size) to avoid illegal
# CUDA memory access.
self.image_hashes += other.image_hashes
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
optional_args = [
"image_sizes",
@@ -297,11 +288,11 @@ class Req:
# The number of cached tokens, that were already cached in the KV cache
self.cached_tokens = 0
def extend_image_inputs(self, image_inputs, vocab_size):
def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
self.image_inputs = image_inputs
else:
self.image_inputs.merge(image_inputs, vocab_size)
self.image_inputs.merge(image_inputs)
# whether request reached finished condition
def finished(self) -> bool:
......
@@ -526,8 +526,9 @@ class Scheduler:
self,
recv_req: TokenizedGenerateReqInput,
):
# Create a new request
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
# Create a new request
if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds
seq_length = len(recv_req.input_embeds)
@@ -558,20 +559,20 @@ class Scheduler:
self.waiting_queue.append(req)
return
# Image inputs
# Handle image inputs
if recv_req.image_inputs is not None:
image_inputs = ImageInputs.from_dict(
recv_req.image_inputs, self.model_config.vocab_size
)
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
)
req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
req.extend_image_inputs(image_inputs)
if len(req.origin_input_ids) > self.max_req_input_len:
req.finished_reason = FINISH_ABORT(
"Image request length is longer than the KV cache pool size or "
"the max context length aborting because you cannot truncate the image embeds"
"the max context length. "
"Abort this request because you cannot truncate the image embeds"
)
req.image_inputs = None
req.origin_input_ids = [0]
@@ -579,6 +580,7 @@ class Scheduler:
self.waiting_queue.append(req)
return
# Copy more attributes
req.return_logprob = recv_req.return_logprob
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
......
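To illustrate the scheduler path above: `pad_input_ids_func` expands each image placeholder in the prompt into dummy token ids derived from that image's hash, so requests sharing an image also share a radix-cache prefix. A minimal sketch under stated assumptions; the function below is hypothetical (one placeholder token per image, a fixed expansion width), not SGLang's actual implementation:

def pad_input_ids_sketch(input_ids, pad_values, image_token_id, tokens_per_image):
    # Replace each image placeholder with a run of hash-derived dummy ids.
    out, img_idx = [], 0
    for tok in input_ids:
        if tok == image_token_id and img_idx < len(pad_values):
            out.extend([pad_values[img_idx]] * tokens_per_image)
            img_idx += 1
        else:
            out.append(tok)
    return out

# Example: token 999 stands for the image placeholder and expands to three dummy ids.
print(pad_input_ids_sketch([1, 999, 2], [123456789], 999, 3))
# -> [1, 123456789, 123456789, 123456789, 2]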
@@ -10,10 +10,7 @@
# limitations under the License.
# ==============================================================================
import copy
import uuid
from dataclasses import dataclass
from typing import Optional
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
......
@@ -216,6 +216,7 @@ class TokenizerManager:
input_ids = obj.input_ids
if self.is_generation:
# TODO: also support getting embeddings for multimodal models
image_inputs: Dict = await self.image_processor.process_images_async(
obj.image_data, input_text or input_ids, obj
)
......
@@ -147,6 +147,11 @@ class LlavaBaseForCausalLM(nn.Module):
else:
max_image_offset.append(-1)
# Clamp the input ids. The input_ids for image tokens are
# filled with image hash values for prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
......
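The Llava change above and the Qwen2-VL change below follow the same pattern: clamp the hash-filled ids so the embedding lookup stays in bounds, embed the text, then overwrite the image positions with vision features. A hedged, self-contained sketch of that pattern; names, shapes, and the one-embedding-per-offset layout are illustrative, not the actual model code:

import torch
import torch.nn as nn

def embed_with_images(input_ids, image_embeds, image_offsets, embed_tokens, vocab_size):
    # input_ids: (seq_len,) and may contain hash-derived pad values >= vocab_size.
    # Clamp first so the embedding lookup cannot index out of bounds.
    input_ids = input_ids.clamp(min=0, max=vocab_size - 1)
    inputs_embeds = embed_tokens(input_ids)
    # The clamped ids at image positions are placeholders; replace their
    # embeddings with the vision-tower outputs.
    for i, offset in enumerate(image_offsets):
        inputs_embeds[offset] = image_embeds[i]
    return inputs_embeds

# Example: the id at position 1 is a pad value and receives a vision embedding.
embed = nn.Embedding(32_000, 16)
ids = torch.tensor([1, (1 << 30) - 7, 2])
out = embed_with_images(ids, torch.randn(1, 16), [1], embed, vocab_size=32_000)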
@@ -597,13 +597,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
"""
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions
image_inputs = None
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions
if (
forward_batch.forward_mode.is_decode()
or image_inputs is None
@@ -617,6 +619,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
f"(3, seq_len) positions, but got {positions.size()}"
)
# Clamp the input ids. The input_ids for image tokens are
# filled with image hash values for prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
......