Unverified Commit 27f4c2fd authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Renderer] Separate out `RendererConfig` from `ModelConfig` (#30145)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent a49d813f
......@@ -22,7 +22,7 @@ Declare supported languages and capabilities:
import torch
from torch import nn
from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.config import RendererConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType
from vllm.model_executor.models.interfaces import SupportsTranscription
......@@ -52,7 +52,7 @@ This is for controlling general behavior of the API when serving your model:
@classmethod
def get_speech_to_text_config(
cls,
model_config: ModelConfig,
renderer_config: RendererConfig,
task_type: Literal["transcribe", "translate"],
) -> SpeechToTextConfig:
return SpeechToTextConfig(
......@@ -83,7 +83,7 @@ Return a dict containing `multi_modal_data` with the audio, and either a `prompt
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
......@@ -120,7 +120,7 @@ Return a dict with separate `encoder_prompt` and `decoder_prompt` entries:
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
......@@ -183,7 +183,7 @@ Provide a fast duration→token estimate to improve streaming usage statistics:
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
) -> int | None:
# Return None if unknown; otherwise return an estimate.
return int(audio_duration_s * stt_config.sample_rate // 320) # example
......@@ -216,7 +216,7 @@ Relevant server logic:
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
renderer_config=self.renderer_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
......
......@@ -17,6 +17,7 @@ from vllm.config import (
DeviceConfig,
ModelConfig,
PassConfig,
RendererConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
......@@ -276,6 +277,7 @@ def sequence_parallelism_pass_on_test_model(
vllm_config = VllmConfig(
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
device_config=device_config,
compilation_config=compilation_config,
)
......
......@@ -15,6 +15,7 @@ from vllm.config import (
CompilationConfig,
ModelConfig,
PassConfig,
RendererConfig,
VllmConfig,
set_current_vllm_config,
)
......@@ -219,8 +220,11 @@ def test_fix_functionalization(
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
model_config = ModelConfig(dtype=dtype)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(
......
......@@ -15,6 +15,7 @@ from vllm.config import (
CompilationMode,
ModelConfig,
PassConfig,
RendererConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -154,8 +155,11 @@ def test_fusion_rmsnorm_quant(
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
model_config = ModelConfig(dtype=dtype)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
......
......@@ -24,6 +24,7 @@ from vllm.config import (
CompilationMode,
ModelConfig,
PassConfig,
RendererConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
......@@ -325,6 +326,7 @@ def test_attention_quant_pattern(
)
vllm_config = VllmConfig(
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
scheduler_config=SchedulerConfig(
max_num_seqs=1024,
max_model_len=model_config.max_model_len,
......
......@@ -7,7 +7,7 @@ import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import ModelConfig, VllmConfig
from vllm.config import ModelConfig, RendererConfig, VllmConfig
# dummy custom pass that doesn't inherit
......@@ -43,7 +43,11 @@ class ProperPass(InductorPass):
)
def test_pass_manager_uuid(callable):
# Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
model_config = ModelConfig(dtype=torch.bfloat16)
config = VllmConfig(
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
)
pass_manager = PostGradPassManager()
pass_manager.configure(config)
......
......@@ -19,6 +19,7 @@ from vllm.config import (
CompilationMode,
ModelConfig,
PassConfig,
RendererConfig,
VllmConfig,
set_current_vllm_config,
)
......@@ -133,8 +134,10 @@ def test_qk_norm_rope_fusion(
if enable_rope_custom_op:
custom_ops.append("+rotary_embedding")
model_config = ModelConfig(dtype=dtype)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
......
......@@ -5,6 +5,7 @@ from vllm.config import (
DeviceConfig,
KVTransferConfig,
ModelConfig,
RendererConfig,
VllmConfig,
set_current_vllm_config,
)
......@@ -47,6 +48,7 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
vllm_config = VllmConfig(
device_config=DeviceConfig("cpu"),
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
kv_transfer_config=kv_transfer_config,
)
with set_current_vllm_config(vllm_config):
......@@ -70,6 +72,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector():
vllm_config = VllmConfig(
device_config=DeviceConfig("cpu"),
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
kv_transfer_config=kv_transfer_config,
)
with set_current_vllm_config(vllm_config):
......
......@@ -3,7 +3,6 @@
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.tokenizers import get_tokenizer
......@@ -107,24 +106,11 @@ def test_get_gen_prompt(
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
renderer_config = model_info.build_renderer_config(model)
# Initialize the tokenizer
tokenizer = get_tokenizer(
tokenizer_name=model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
renderer_config.tokenizer,
trust_remote_code=renderer_config.trust_remote_code,
)
template_content = load_chat_template(chat_template=template)
......@@ -143,7 +129,7 @@ def test_get_gen_prompt(
tokenizer=tokenizer,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
model_config=model_config,
renderer_config=renderer_config,
tools=None,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
......
......@@ -33,26 +33,34 @@ class MockModelConfig:
"""Minimal mock ModelConfig for testing."""
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
tokenizer_mode: str = "auto"
max_model_len: int = 100
tokenizer_revision: str | None = None
multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
hf_config: MockHFConfig = field(default_factory=MockHFConfig)
logits_processors: list[str] | None = None
logits_processor_pattern: str | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@dataclass
class MockRendererConfig:
"""Minimal mock RendererConfig for testing."""
model_config: MockModelConfig
tokenizer: str = MODEL_NAME
tokenizer_mode: str = "auto"
tokenizer_revision: str | None = None
skip_tokenizer_init: bool = False
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
class MockLoRAResolver(LoRAResolver):
async def resolve_lora(
self, base_model_name: str, lora_name: str
......@@ -114,6 +122,7 @@ def mock_serving_setup():
mock_engine.add_lora.reset_mock()
mock_engine.model_config = MockModelConfig()
mock_engine.renderer_config = MockRendererConfig(mock_engine.model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......
......@@ -346,27 +346,33 @@ class MockHFConfig:
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processors: list[str] | None = None
logits_processor_pattern = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@dataclass
class MockRendererConfig:
model_config: MockModelConfig = field(default_factory=MockModelConfig)
tokenizer = MODEL_NAME
tokenizer_mode = "auto"
tokenizer_revision = None
skip_tokenizer_init = False
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
......@@ -399,6 +405,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
@dataclass
class MockEngine:
model_config: MockModelConfig = field(default_factory=MockModelConfig)
renderer_config: MockRendererConfig = field(default_factory=MockRendererConfig)
input_processor: MagicMock = field(default_factory=MagicMock)
io_processor: MagicMock = field(default_factory=MagicMock)
......@@ -429,6 +436,7 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.renderer_config = MockRendererConfig(mock_engine.model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......@@ -459,6 +467,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.renderer_config = MockRendererConfig(mock_engine.model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......@@ -492,6 +501,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.renderer_config = MockRendererConfig(mock_model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......@@ -537,6 +547,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.renderer_config = MockRendererConfig(mock_model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......@@ -583,6 +594,7 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.renderer_config = MockRendererConfig(mock_model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......@@ -629,6 +641,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.renderer_config = MockRendererConfig(mock_model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......@@ -662,6 +675,7 @@ async def test_serving_chat_data_parallel_rank_extraction():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.renderer_config = MockRendererConfig(mock_engine.model_config)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
......
......@@ -7,7 +7,7 @@ from unittest.mock import Mock
import pytest
from vllm.config import ModelConfig
from vllm.config import ModelConfig, RendererConfig
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.tokenizers import MistralTokenizer
......@@ -19,10 +19,16 @@ def serving() -> OpenAIServing:
# Create minimal mocks
engine_client = Mock()
model_config = Mock(spec=ModelConfig)
model_config.max_model_len = 32768
renderer_config = Mock(spec=RendererConfig)
renderer_config.model_config = model_config
models = Mock(spec=OpenAIServingModels)
models.model_config = model_config
models.renderer_config = renderer_config
models.input_processor = Mock()
models.io_processor = Mock()
......
......@@ -6,7 +6,7 @@ from unittest.mock import MagicMock
import pytest
from vllm.config import ModelConfig
from vllm.config import ModelConfig, RendererConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
......@@ -27,9 +27,15 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
async def _async_serving_models_init() -> OpenAIServingModels:
mock_engine_client = MagicMock(spec=EngineClient)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config = MagicMock(spec=ModelConfig)
mock_model_config.max_model_len = 2048
mock_renderer_config = MagicMock(spec=RendererConfig)
mock_renderer_config.model_config = mock_model_config
mock_engine_client.model_config = mock_model_config
mock_engine_client.renderer_config = mock_renderer_config
mock_engine_client.input_processor = MagicMock()
mock_engine_client.io_processor = MagicMock()
......
......@@ -12,7 +12,7 @@ from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import ModelConfig
from vllm.config import ModelConfig, RendererConfig
from vllm.entrypoints.chat_utils import (
_try_extract_ast,
apply_mistral_chat_template,
......@@ -233,7 +233,7 @@ def test_parse_chat_messages_single_image(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -265,7 +265,7 @@ def test_parse_chat_messages_single_image_with_uuid(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -295,7 +295,7 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -328,7 +328,7 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -369,7 +369,7 @@ def test_parse_chat_messages_multiple_images_with_uuids(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -409,7 +409,7 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -451,7 +451,7 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -485,7 +485,7 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -516,7 +516,7 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -554,7 +554,7 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -595,7 +595,7 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -634,7 +634,7 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -660,7 +660,7 @@ def test_parse_chat_messages_empty_system(
"content": [{"type": "text", "text": "Who are you?"}],
},
],
mistral_model_config,
RendererConfig(model_config=mistral_model_config),
content_format="string",
)
assert conversation == [
......@@ -677,7 +677,7 @@ def test_parse_chat_messages_empty_system(
"content": [{"type": "text", "text": "Who are you?"}],
},
],
mistral_model_config,
RendererConfig(model_config=mistral_model_config),
content_format="openai",
)
assert conversation == [
......@@ -701,7 +701,7 @@ async def test_parse_chat_messages_single_image_async(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -730,7 +730,7 @@ def test_parse_chat_messages_multiple_images(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -758,7 +758,7 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -786,7 +786,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
],
}
],
phi3v_model_config_image_embeds,
RendererConfig(model_config=phi3v_model_config_image_embeds),
content_format="string",
)
......@@ -818,7 +818,7 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
],
}
],
audio_embeds_model_config,
RendererConfig(model_config=audio_embeds_model_config),
content_format="string",
)
......@@ -858,7 +858,7 @@ def test_parse_chat_messages_audio_embeds_with_string(
],
}
],
audio_embeds_model_config,
RendererConfig(model_config=audio_embeds_model_config),
content_format="string",
)
......@@ -900,7 +900,7 @@ async def test_parse_chat_messages_audio_embeds_async(
],
}
],
audio_embeds_model_config,
RendererConfig(model_config=audio_embeds_model_config),
content_format="string",
)
......@@ -1108,7 +1108,7 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
],
}
],
phi3v_model_config_image_embeds,
RendererConfig(model_config=phi3v_model_config_image_embeds),
content_format="string",
)
......@@ -1144,7 +1144,7 @@ async def test_parse_chat_messages_multiple_images_async(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -1176,7 +1176,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
assert conversation == [
......@@ -1208,7 +1208,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -1245,7 +1245,7 @@ def test_parse_chat_messages_multiple_images_across_messages(
],
},
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -1289,7 +1289,7 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
],
},
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -1314,7 +1314,7 @@ def test_parse_chat_messages_context_text_format(
{"role": "assistant", "content": "Some stuff."},
{"role": "user", "content": "What about this one?"},
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="openai",
)
......@@ -1367,7 +1367,7 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -1410,7 +1410,7 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
],
},
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -1430,7 +1430,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
],
}
],
phi3v_model_config,
RendererConfig(model_config=phi3v_model_config),
content_format="string",
)
......@@ -1464,7 +1464,7 @@ def test_parse_chat_messages_multiple_images_interleave(
],
}
],
phi3v_model_config_mm_interleaved,
RendererConfig(model_config=phi3v_model_config_mm_interleaved),
content_format="string",
)
......@@ -1500,7 +1500,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
],
}
],
phi3v_model_config_mm_interleaved,
RendererConfig(model_config=phi3v_model_config_mm_interleaved),
content_format="string",
)
......@@ -1545,7 +1545,7 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
],
}
],
phi3v_model_config_mm_interleaved,
RendererConfig(model_config=phi3v_model_config_mm_interleaved),
content_format="string",
)
......@@ -1583,7 +1583,7 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
],
},
],
phi3v_model_config_mm_interleaved,
RendererConfig(model_config=phi3v_model_config_mm_interleaved),
content_format="string",
)
......@@ -1631,7 +1631,7 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
],
},
],
phi3v_model_config_mm_interleaved,
RendererConfig(model_config=phi3v_model_config_mm_interleaved),
content_format="string",
)
......@@ -1675,7 +1675,7 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
],
},
],
qwen25omni_model_config_mm_interleaved,
RendererConfig(model_config=qwen25omni_model_config_mm_interleaved),
content_format="string",
)
......@@ -1743,7 +1743,7 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
],
},
],
qwen25omni_model_config_mm_interleaved,
RendererConfig(model_config=qwen25omni_model_config_mm_interleaved),
content_format="string",
)
......@@ -1813,7 +1813,7 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
],
},
],
qwen25omni_model_config_mm_interleaved,
RendererConfig(model_config=qwen25omni_model_config_mm_interleaved),
content_format="string",
)
......@@ -1879,7 +1879,7 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
],
},
],
qwen25omni_model_config_mm_interleaved,
RendererConfig(model_config=qwen25omni_model_config_mm_interleaved),
content_format="string",
)
......@@ -1927,7 +1927,7 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
],
}
],
phi3v_model_config_mm_interleaved,
RendererConfig(model_config=phi3v_model_config_mm_interleaved),
content_format="string",
)
......@@ -1945,24 +1945,11 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Build the tokenizer
renderer_config = model_info.build_renderer_config(model)
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
renderer_config.tokenizer,
trust_remote_code=renderer_config.trust_remote_code,
)
tools = (
......@@ -1985,7 +1972,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
model_config=renderer_config.model_config,
)
assert isinstance(chat_template, str)
......@@ -2047,24 +2034,11 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Build the tokenizer
renderer_config = model_info.build_renderer_config(model)
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
renderer_config.tokenizer,
trust_remote_code=renderer_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
......@@ -2072,7 +2046,7 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
model_config=renderer_config.model_config,
)
with pytest.raises(
ValueError, match="Found unexpected chat template kwargs from request"
......@@ -2143,23 +2117,11 @@ def test_resolve_content_format_hf_defined(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
renderer_config = model_info.build_renderer_config(model)
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
renderer_config.tokenizer,
trust_remote_code=renderer_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
......@@ -2167,7 +2129,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
model_config=renderer_config.model_config,
)
assert isinstance(chat_template, str)
......@@ -2181,7 +2143,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
None,
"auto",
tokenizer,
model_config=model_config,
renderer_config=renderer_config,
)
assert resolved_format == expected_format
......@@ -2203,23 +2165,11 @@ def test_resolve_content_format_fallbacks(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
renderer_config = model_info.build_renderer_config(model)
tokenizer = get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
renderer_config.tokenizer,
trust_remote_code=renderer_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
......@@ -2227,7 +2177,7 @@ def test_resolve_content_format_fallbacks(model, expected_format):
tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
model_config=renderer_config.model_config,
)
assert isinstance(chat_template, str)
......@@ -2241,7 +2191,7 @@ def test_resolve_content_format_fallbacks(model, expected_format):
None,
"auto",
tokenizer,
model_config=model_config,
renderer_config=renderer_config,
)
assert resolved_format == expected_format
......@@ -2272,15 +2222,13 @@ def test_resolve_content_format_fallbacks(model, expected_format):
],
)
def test_resolve_content_format_examples(template_path, expected_format):
model_config = ModelConfig(
PHI3V_MODEL_ID, # Dummy
tokenizer=PHI3V_MODEL_ID, # Dummy
trust_remote_code=True,
)
model = PHI3V_MODEL_ID # Dummy
model_config = ModelConfig(model, trust_remote_code=True)
renderer_config = RendererConfig(model_config=model_config, tokenizer=model)
dummy_tokenizer = get_tokenizer(
PHI3V_MODEL_ID, # Dummy
trust_remote_code=model_config.trust_remote_code,
renderer_config.tokenizer,
trust_remote_code=renderer_config.trust_remote_code,
)
dummy_tokenizer.chat_template = None
......@@ -2297,7 +2245,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
None,
"auto",
dummy_tokenizer,
model_config=model_config,
renderer_config=renderer_config,
)
assert resolved_format == expected_format
......@@ -2332,7 +2280,7 @@ def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
conversation_with_thinking, _, _ = parse_chat_messages(
messages,
mistral_model_config,
RendererConfig(model_config=mistral_model_config),
content_format="openai",
)
......@@ -2432,7 +2380,7 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
],
}
],
qwen2_audio_model_config,
RendererConfig(model_config=qwen2_audio_model_config),
content_format="string",
)
......@@ -2466,7 +2414,7 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
],
}
],
qwen2_audio_model_config,
RendererConfig(model_config=qwen2_audio_model_config),
content_format="string",
)
......
......@@ -8,7 +8,7 @@ import torch
from safetensors.torch import load_file
from torch import nn
from vllm.config import ModelConfig, VllmConfig
from vllm.config import ModelConfig, RendererConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.layers import (
ColumnParallelLinearWithLoRA,
......@@ -422,7 +422,11 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
)
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config = VllmConfig(
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
lora_config=lora_config,
)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
......@@ -525,7 +529,11 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
)
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config = VllmConfig(
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
lora_config=lora_config,
)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
......
......@@ -11,6 +11,7 @@ from vllm.config import (
DeviceConfig,
ModelConfig,
ParallelConfig,
RendererConfig,
SchedulerConfig,
VllmConfig,
)
......@@ -43,6 +44,7 @@ def test_worker_apply_lora(qwen3_lora_files):
vllm_config = VllmConfig(
model_config=model_config,
renderer_config=RendererConfig(model_config=model_config),
load_config=LoadConfig(
download_dir=None,
load_format="dummy",
......
......@@ -42,8 +42,10 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
"Write a short story about a robot that dreams for the first time.\n"
)
model_config = vllm_model.llm.llm_engine.model_config
model_tokenizer = vllm_model.llm.llm_engine.tokenizer
llm_engine = vllm_model.llm.llm_engine
model_config = llm_engine.model_config
renderer_config = llm_engine.renderer_config
tokenizer = llm_engine.tokenizer
# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
......@@ -54,8 +56,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded
assert model_config.tokenizer == "BAAI/bge-base-en-v1.5"
assert model_tokenizer.model_max_length == 512
assert renderer_config.tokenizer == "BAAI/bge-base-en-v1.5"
assert tokenizer.model_max_length == 512
def check_model(model):
assert isinstance(model, BertEmbeddingModel)
......@@ -86,8 +88,10 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"Write a short story about a robot that dreams for the first time.\n"
)
model_config = vllm_model.llm.llm_engine.model_config
model_tokenizer = vllm_model.llm.llm_engine.tokenizer
llm_engine = vllm_model.llm.llm_engine
model_config = llm_engine.model_config
renderer_config = llm_engine.renderer_config
tokenizer = llm_engine.tokenizer
# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
......@@ -98,8 +102,8 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded
assert model_config.tokenizer == "intfloat/multilingual-e5-base"
assert model_tokenizer.model_max_length == 512
assert renderer_config.tokenizer == "intfloat/multilingual-e5-base"
assert tokenizer.model_max_length == 512
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
......@@ -128,7 +132,7 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"Write a short story about a robot that dreams for the first time.\n"
)
assert vllm_model.llm.llm_engine.model_config.tokenizer == model_name
assert vllm_model.llm.llm_engine.renderer_config.tokenizer == model_name
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
......
......@@ -6,7 +6,7 @@ import pytest
from scipy.spatial.distance import cosine
from vllm import LLM, SamplingParams
from vllm.config import ModelConfig
from vllm.config import ModelConfig, RendererConfig
from ....utils import RemoteOpenAIServer
......@@ -31,7 +31,8 @@ def test_find_array():
dtype="bfloat16",
seed=0,
)
pooling = GritLMMeanPool(model_config=model_config)
renderer_config = RendererConfig(model_config=model_config)
pooling = GritLMMeanPool(renderer_config=renderer_config)
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
......
......@@ -25,7 +25,6 @@ from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingC
from vllm.tokenizers import (
MistralTokenizer,
TokenizerLike,
cached_tokenizer_from_config,
)
from ....multimodal.utils import random_audio, random_image, random_video
......@@ -212,31 +211,20 @@ def _test_processing_correctness(
else:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id_or_arch)
model_id = model_id_or_arch
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
model_config = ModelConfig(
model_id,
tokenizer=model_info.tokenizer or model_id,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
renderer_config = model_info.build_renderer_config(
model=model_id,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
model_config = renderer_config.model_config
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = InputProcessingContext(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),
)
ctx = InputProcessingContext.from_config(renderer_config)
cache = MultiModalProcessorOnlyCache(model_config)
processing_info = factories.info(ctx)
......
......@@ -40,7 +40,7 @@ def test_processor_override(
mm_processor_kwargs=None,
limit_mm_per_prompt={"video": 1},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.renderer_config)
tokenizer = processor.info.get_tokenizer()
hf_processor_mm_kwargs = {"fps": fps}
......@@ -79,7 +79,7 @@ def test_video_loader_consistency(
mm_processor_kwargs=None,
limit_mm_per_prompt={"video": 1},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.renderer_config)
hf_processor_mm_kwargs = {"fps": fps}
# Build the image str / prompt based on the number of images we pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment