Unverified commit 48312e57, authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Make `PlaceholderRange.get_num_embeds` a method (#34035)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent bc32444b
...@@ -48,7 +48,7 @@ def test_profiling(model_id: str, max_model_len: int): ...@@ -48,7 +48,7 @@ def test_profiling(model_id: str, max_model_len: int):
) # image start, image, image end ) # image start, image, image end
assert total_num_patches == sum( assert total_num_patches == sum(
item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"] item.get_num_embeds() for item in mm_inputs["mm_placeholders"]["image"]
) )
assert total_tokens == sum( assert total_tokens == sum(
placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"] placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"]
......
...@@ -19,7 +19,7 @@ from vllm.multimodal.inputs import PlaceholderRange ...@@ -19,7 +19,7 @@ from vllm.multimodal.inputs import PlaceholderRange
def test_placeholder_range_get_num_embeds(is_embed, expected): def test_placeholder_range_get_num_embeds(is_embed, expected):
length = len(is_embed) if is_embed is not None else 5 length = len(is_embed) if is_embed is not None else 5
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed) pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
assert pr.get_num_embeds == expected assert pr.get_num_embeds() == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -187,7 +187,7 @@ def test_schedule_request_multi_images_respect_compute_limit(): ...@@ -187,7 +187,7 @@ def test_schedule_request_multi_images_respect_compute_limit():
def test_encoder_cache_with_is_embed_mask(): def test_encoder_cache_with_is_embed_mask():
class MockRequestWithMask(MockRequest): class MockRequestWithMask(MockRequest):
def get_num_encoder_embeds(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
return self.mm_features[input_id].mm_position.get_num_embeds return self.mm_features[input_id].mm_position.get_num_embeds()
is_embed = torch.zeros(100, dtype=torch.bool) is_embed = torch.zeros(100, dtype=torch.bool)
is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True
...@@ -207,7 +207,7 @@ def test_encoder_cache_with_is_embed_mask(): ...@@ -207,7 +207,7 @@ def test_encoder_cache_with_is_embed_mask():
assert "img1" in manager.cached assert "img1" in manager.cached
old_size = 100 old_size = 100
new_size = request.mm_features[0].mm_position.get_num_embeds new_size = request.mm_features[0].mm_position.get_num_embeds()
assert new_size == 8 assert new_size == 8
savings_ratio = old_size / new_size savings_ratio = old_size / new_size
assert savings_ratio == 12.5 assert savings_ratio == 12.5
...@@ -216,7 +216,7 @@ def test_encoder_cache_with_is_embed_mask(): ...@@ -216,7 +216,7 @@ def test_encoder_cache_with_is_embed_mask():
def test_encoder_cache_mask_based_retrieval(): def test_encoder_cache_mask_based_retrieval():
class MockRequestWithMask(MockRequest): class MockRequestWithMask(MockRequest):
def get_num_encoder_embeds(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
return self.mm_features[input_id].mm_position.get_num_embeds return self.mm_features[input_id].mm_position.get_num_embeds()
is_embed = torch.tensor( is_embed = torch.tensor(
[False, False, True, True, False, True, True, True, False, False] [False, False, True, True, False, True, True, True, False, False]
...@@ -233,7 +233,7 @@ def test_encoder_cache_mask_based_retrieval(): ...@@ -233,7 +233,7 @@ def test_encoder_cache_mask_based_retrieval():
manager = EncoderCacheManager(cache_size=50) manager = EncoderCacheManager(cache_size=50)
manager.allocate(request, 0) manager.allocate(request, 0)
assert request.mm_features[0].mm_position.get_num_embeds == 5 assert request.mm_features[0].mm_position.get_num_embeds() == 5
start_idx = 2 start_idx = 2
end_idx = 8 end_idx = 8
......
...@@ -33,7 +33,7 @@ def get_mm_max_toks_per_item( ...@@ -33,7 +33,7 @@ def get_mm_max_toks_per_item(
) )
return { return {
modality: sum(item.get_num_embeds for item in placeholders) modality: sum(item.get_num_embeds() for item in placeholders)
for modality, placeholders in mm_inputs["mm_placeholders"].items() for modality, placeholders in mm_inputs["mm_placeholders"].items()
} }
......
...@@ -199,7 +199,6 @@ class PlaceholderRange: ...@@ -199,7 +199,6 @@ class PlaceholderRange:
def embeds_cumsum(self) -> torch.Tensor | None: def embeds_cumsum(self) -> torch.Tensor | None:
return None if self.is_embed is None else self.is_embed.cumsum(dim=0) return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
@cached_property
def get_num_embeds(self) -> int: def get_num_embeds(self) -> int:
if self.embeds_cumsum is None: if self.embeds_cumsum is None:
return self.length return self.length
......
...@@ -1100,7 +1100,7 @@ class Scheduler(SchedulerInterface): ...@@ -1100,7 +1100,7 @@ class Scheduler(SchedulerInterface):
for i, mm_feature in enumerate(mm_features): for i, mm_feature in enumerate(mm_features):
start_pos = mm_feature.mm_position.offset start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds num_encoder_embeds = mm_feature.mm_position.get_num_embeds()
item_identifier = mm_feature.identifier item_identifier = mm_feature.identifier
# The encoder output is needed if the two ranges overlap: # The encoder output is needed if the two ranges overlap:
......
...@@ -786,7 +786,7 @@ class InputProcessor: ...@@ -786,7 +786,7 @@ class InputProcessor:
decoder_mm_positions = prompt_inputs["mm_placeholders"] decoder_mm_positions = prompt_inputs["mm_placeholders"]
for modality, mm_positions in decoder_mm_positions.items(): for modality, mm_positions in decoder_mm_positions.items():
for mm_position in mm_positions: for mm_position in mm_positions:
embed_length = mm_position.get_num_embeds embed_length = mm_position.get_num_embeds()
if embed_length > self.mm_encoder_cache_size: if embed_length > self.mm_encoder_cache_size:
raise ValueError( raise ValueError(
f"The {prompt_type} prompt contains a(n) {modality} item " f"The {prompt_type} prompt contains a(n) {modality} item "
......
...@@ -260,7 +260,7 @@ class Request: ...@@ -260,7 +260,7 @@ class Request:
def get_num_encoder_embeds(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self.mm_features) assert input_id < len(self.mm_features)
return self.mm_features[input_id].mm_position.get_num_embeds return self.mm_features[input_id].mm_position.get_num_embeds()
def record_event( def record_event(
self, self,
......
...@@ -2326,7 +2326,7 @@ class GPUModelRunner( ...@@ -2326,7 +2326,7 @@ class GPUModelRunner(
# Prefer pos_info.get_num_embeds to count precise MM embedding tokens. # Prefer pos_info.get_num_embeds to count precise MM embedding tokens.
num_tokens = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined] num_tokens = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
pos_info.get_num_embeds pos_info.get_num_embeds()
) )
prompt_lora_mapping.append(lora_id) prompt_lora_mapping.append(lora_id)
token_lora_mapping.extend([lora_id] * num_tokens) token_lora_mapping.extend([lora_id] * num_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment