Unverified Commit ae09b929 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

ci: apply static type check to vllm multimodal handlers (#6027)

parent 20ccc9b2
...@@ -81,7 +81,7 @@ class EncodeWorkerHandler: ...@@ -81,7 +81,7 @@ class EncodeWorkerHandler:
self.vision_encoder, self.projector = get_encoder_components( self.vision_encoder, self.projector = get_encoder_components(
self.model, self.vision_model self.model, self.vision_model
) )
self._connector = None self._connector: connect.Connector | None = None
self._accumulated_time = 0.0 self._accumulated_time = 0.0
self._processed_requests = 0 self._processed_requests = 0
self.readables = [] self.readables = []
...@@ -253,6 +253,9 @@ class EncodeWorkerHandler: ...@@ -253,6 +253,9 @@ class EncodeWorkerHandler:
request.multimodal_inputs[idx].serialized_request = cache_path request.multimodal_inputs[idx].serialized_request = cache_path
else: else:
descriptor = connect.Descriptor(embedding_item.embeddings_cpu) descriptor = connect.Descriptor(embedding_item.embeddings_cpu)
assert (
self._connector is not None
), "Connector not initialized; call async_init() first"
self.readables.append( self.readables.append(
await self._connector.create_readable(descriptor) await self._connector.create_readable(descriptor)
) )
......
...@@ -158,10 +158,10 @@ class PreprocessedHandler(ProcessMixIn): ...@@ -158,10 +158,10 @@ class PreprocessedHandler(ProcessMixIn):
for encode_res in encode_res_gen: for encode_res in encode_res_gen:
async for response in encode_res: async for response in encode_res:
logger.debug(f"Received response from encode worker: {response}") logger.debug(f"Received response from encode worker: {response}")
output = vLLMMultimodalRequest.model_validate_json(response.data()) output = vLLMMultimodalRequest.model_validate_json(response.data()) # type: ignore[attr-defined]
worker_request.multimodal_inputs.extend(output.multimodal_inputs) worker_request.multimodal_inputs.extend(output.multimodal_inputs)
response_generator = await self.pd_worker_client.round_robin( response_generator = await self.pd_worker_client.round_robin( # type: ignore[call-arg]
worker_request.model_dump_json(), context=context worker_request.model_dump_json(), context=context
) )
...@@ -180,7 +180,7 @@ class PreprocessedHandler(ProcessMixIn): ...@@ -180,7 +180,7 @@ class PreprocessedHandler(ProcessMixIn):
async for resp in response_generator: async for resp in response_generator:
# Deserialize the response from the engine # Deserialize the response from the engine
# Creates correct vLLM objects for each field # Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data()) output = MyRequestOutput.model_validate_json(resp.data()) # type: ignore[attr-defined]
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object # OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
res = RequestOutput( res = RequestOutput(
...@@ -287,7 +287,7 @@ class ECProcessorHandler(PreprocessedHandler): ...@@ -287,7 +287,7 @@ class ECProcessorHandler(PreprocessedHandler):
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
encoder_worker_client: Client, encoder_worker_client: Client,
pd_worker_client: Client, pd_worker_client: Client,
prompt_template: str = None, prompt_template: str | None = None,
): ):
""" """
Initialize the ECConnector processor. Initialize the ECConnector processor.
...@@ -398,7 +398,7 @@ class ECProcessorHandler(PreprocessedHandler): ...@@ -398,7 +398,7 @@ class ECProcessorHandler(PreprocessedHandler):
) )
# Send single request to PD worker with ALL images # Send single request to PD worker with ALL images
response_generator = await self.pd_worker_client.round_robin( response_generator = await self.pd_worker_client.round_robin( # type: ignore[call-arg]
worker_request.model_dump_json(), context=context worker_request.model_dump_json(), context=context
) )
......
...@@ -122,7 +122,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -122,7 +122,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
component: Component, component: Component,
engine_client: AsyncLLM, engine_client: AsyncLLM,
config, config,
decode_worker_client: Client = None, decode_worker_client: Client | None = None,
shutdown_event=None, shutdown_event=None,
): ):
# Get default_sampling_params from config # Get default_sampling_params from config
...@@ -157,7 +157,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -157,7 +157,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
# Note: This is synchronous initialization, async initialization happens in async_init # Note: This is synchronous initialization, async initialization happens in async_init
self._connector = None # Will be initialized in async_init self._connector: connect.Connector | None = (
None # Will be initialized in async_init
)
self.image_loader = ImageLoader() self.image_loader = ImageLoader()
logger.info("Multimodal PD Worker has been initialized") logger.info("Multimodal PD Worker has been initialized")
...@@ -279,7 +281,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -279,7 +281,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
) in await self.decode_worker_client.round_robin( ) in await self.decode_worker_client.round_robin(
decode_request.model_dump_json() decode_request.model_dump_json()
): ):
output = MyRequestOutput.model_validate_json(decode_response.data()) output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore[attr-defined]
yield MyRequestOutput( yield MyRequestOutput(
request_id=output.request_id, request_id=output.request_id,
prompt=output.prompt, prompt=output.prompt,
......
...@@ -163,7 +163,6 @@ addopts = [ ...@@ -163,7 +163,6 @@ addopts = [
"--ignore-glob=components/src/dynamo/sglang/request_handlers/*", "--ignore-glob=components/src/dynamo/sglang/request_handlers/*",
"--ignore-glob=components/src/dynamo/sglang/multimodal_utils/*", "--ignore-glob=components/src/dynamo/sglang/multimodal_utils/*",
"--ignore-glob=components/src/dynamo/vllm/multimodal_utils/*", "--ignore-glob=components/src/dynamo/vllm/multimodal_utils/*",
"--ignore-glob=components/src/dynamo/vllm/multimodal_handlers/*",
"--ignore-glob=examples/backends/sglang/slurm_jobs/*", "--ignore-glob=examples/backends/sglang/slurm_jobs/*",
# FIXME: Get relative/generic blob paths to work here # FIXME: Get relative/generic blob paths to work here
] ]
...@@ -189,6 +188,12 @@ filterwarnings = [ ...@@ -189,6 +188,12 @@ filterwarnings = [
# pytest-benchmark automatically disables when xdist is active, ignore the warning # pytest-benchmark automatically disables when xdist is active, ignore the warning
"ignore:.*Benchmarks are automatically disabled.*:pytest_benchmark.logger.PytestBenchmarkWarning", "ignore:.*Benchmarks are automatically disabled.*:pytest_benchmark.logger.PytestBenchmarkWarning",
################################################################################################
# vLLM
################################################################################################
# vLLM tokenizer deprecation warning (AnyTokenizer moved to vllm.tokenizers.TokenizerLike)
"ignore:.*vllm\\.transformers_utils\\.tokenizer\\.AnyTokenizer.*has been moved.*:DeprecationWarning",
################################################################################################ ################################################################################################
# TRT-LLM # TRT-LLM
################################################################################################ ################################################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment