Unverified Commit 2f2212e6 authored by Christian Pinto's avatar Christian Pinto Committed by GitHub
Browse files

Split generic IO Processor plugins tests from Terratorch specific ones (#35756)


Signed-off-by: default avatarChristian Pinto <christian.pinto@ibm.com>
parent 18e01a0a
......@@ -15,9 +15,12 @@ steps:
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests
# begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin
# begin io_processor plugins test
# test generic io_processor plugins functions
- pytest -v -s ./plugins_tests/test_io_processor_plugins.py
# test Terratorch io_processor plugins
- pip install -e ./plugins/prithvi_io_processor_plugin
- pytest -v -s plugins_tests/test_io_processor_plugins.py
- pytest -v -s plugins_tests/test_terratorch_io_processor_plugins.py
- pip uninstall prithvi_io_processor_plugin -y
# test bge_m3_sparse io_processor plugin
- pip install -e ./plugins/bge_m3_sparse_plugin
......
......@@ -1140,6 +1140,15 @@ class VllmRunner:
return self
def __exit__(self, exc_type, exc_value, traceback):
# Explicitly shutdown the engine core to release GPU resources
# This is needed because when executing consecutive tests, the GC
# might not be fast enough in shutting down the llm engine. This can lead to OOMs
# because when the next test starts some GPU memory is still in use.
try:
self.llm.llm_engine.engine_core.shutdown()
except Exception:
# Ignore shutdown errors as cleanup will still proceed
pass
del self.llm
cleanup_dist_env_and_memory()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
from collections.abc import Sequence
from unittest.mock import MagicMock, patch
import imagehash
import pytest
import requests
from PIL import Image
from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
models_config = {
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
"image_url": "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff", # noqa: E501
"out_hash": "aa6d92ad25926a5e",
"plugin": "prithvi_to_tiff",
},
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars": {
"image_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars/resolve/main/examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif", # noqa: E501
"out_hash": "c07f4f602da73552",
"plugin": "prithvi_to_tiff",
},
}
def _compute_image_hash(base64_data: str) -> str:
# Decode the base64 output and create image from byte stream
decoded_image = base64.b64decode(base64_data)
image = Image.open(io.BytesIO(decoded_image))
# Compute perceptual hash of the output image
return str(imagehash.phash(image))
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.renderers import BaseRenderer
class DummyIOProcessor(IOProcessor):
"""Minimal IOProcessor used as the target of the mocked plugin entry point."""
def pre_process(
self,
prompt: object,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
raise NotImplementedError
def post_process(
self,
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs,
) -> object:
raise NotImplementedError
@pytest.fixture
def my_plugin_entry_points():
"""Patch importlib.metadata.entry_points to expose a single 'my_plugin'
entry point backed by DummyIOProcessor, exercising the full plugin-loading
code path: entry_points → plugin.load() → func() →
resolve_obj_by_qualname → IOProcessor.__init__."""
qualname = f"{DummyIOProcessor.__module__}.{DummyIOProcessor.__qualname__}"
ep = MagicMock()
ep.name = "my_plugin"
ep.value = qualname
ep.load.return_value = lambda: qualname
with patch("importlib.metadata.entry_points", return_value=[ep]):
yield
def test_loading_missing_plugin():
vllm_config = VllmConfig()
renderer = MagicMock(spec=BaseRenderer)
with pytest.raises(ValueError):
get_io_processor(vllm_config, None, "wrong_plugin")
@pytest.fixture(scope="function")
def server(model_name, plugin):
args = [
"--runner",
"pooling",
"--enforce-eager",
"--skip-tokenizer-init",
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
"--max-num-seqs",
"32",
"--io-processor-plugin",
plugin,
"--enable-mm-embeds",
]
with RemoteOpenAIServer(model_name, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
async def test_prithvi_mae_plugin_online(
server: RemoteOpenAIServer,
model_name: str,
image_url: str | dict,
plugin: str,
expected_hash: str,
):
request_payload_url = {
"data": {
"data": image_url,
"data_format": "url",
"image_format": "tiff",
"out_data_format": "b64_json",
},
"priority": 0,
"model": model_name,
"softmax": False,
}
ret = requests.post(
server.url_for("pooling"),
json=request_payload_url,
get_io_processor(
vllm_config, renderer=renderer, plugin_from_init="wrong_plugin"
)
response = ret.json()
# verify the request response is in the correct format
assert (parsed_response := IOProcessorResponse(**response))
def test_loading_plugin(my_plugin_entry_points):
# Plugin name supplied via plugin_from_init.
vllm_config = MagicMock(spec=VllmConfig)
renderer = MagicMock(spec=BaseRenderer)
# verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data
assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])
# Compute the output image hash and compare it against the expected hash
image_hash = _compute_image_hash(plugin_data["data"])
assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
result = get_io_processor(
vllm_config, renderer=renderer, plugin_from_init="my_plugin"
)
assert isinstance(result, DummyIOProcessor)
@pytest.mark.parametrize(
"model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
def test_prithvi_mae_plugin_offline(
vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
):
img_data = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
prompt = dict(data=img_data)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
enable_mm_embeds=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=32,
io_processor_plugin=plugin,
default_torch_num_threads=1,
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(prompt, pooling_task="plugin")
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(hasattr(output, attr) for attr in ["type", "format", "data"])
# Compute the output image hash and compare it against the expected hash
image_hash = _compute_image_hash(output.data)
assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
)
def test_loading_missing_plugin_from_model_config():
# Build a mock VllmConfig whose hf_config advertises a plugin name,
# exercising the model-config code path without loading a real model.
mock_hf_config = MagicMock()
mock_hf_config.to_dict.return_value = {"io_processor_plugin": "wrong_plugin"}
vllm_config = MagicMock(spec=VllmConfig)
vllm_config.model_config.hf_config = mock_hf_config
renderer = MagicMock(spec=BaseRenderer)
with pytest.raises(ValueError):
get_io_processor(vllm_config, renderer=renderer)
def test_loading_plugin_from_model_config(my_plugin_entry_points):
# Plugin name supplied via the model's hf_config.
mock_hf_config = MagicMock()
mock_hf_config.to_dict.return_value = {"io_processor_plugin": "my_plugin"}
vllm_config = MagicMock(spec=VllmConfig)
vllm_config.model_config.hf_config = mock_hf_config
renderer = MagicMock(spec=BaseRenderer)
result = get_io_processor(vllm_config, renderer=renderer)
assert isinstance(result, DummyIOProcessor)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import imagehash
import pytest
import requests
from PIL import Image
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
models_config = {
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
"image_url": "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff", # noqa: E501
"out_hash": "aa6d92ad25926a5e",
"plugin": "prithvi_to_tiff",
},
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars": {
"image_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars/resolve/main/examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif", # noqa: E501
"out_hash": "c07f4f602da73552",
"plugin": "prithvi_to_tiff",
},
}
def _compute_image_hash(base64_data: str) -> str:
# Decode the base64 output and create image from byte stream
decoded_image = base64.b64decode(base64_data)
image = Image.open(io.BytesIO(decoded_image))
# Compute perceptual hash of the output image
return str(imagehash.phash(image))
@pytest.fixture(scope="function")
def server(model_name, plugin):
args = [
"--runner",
"pooling",
"--enforce-eager",
"--skip-tokenizer-init",
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
"--max-num-seqs",
"32",
"--io-processor-plugin",
plugin,
"--enable-mm-embeds",
]
with RemoteOpenAIServer(model_name, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
async def test_prithvi_mae_plugin_online(
server: RemoteOpenAIServer,
model_name: str,
image_url: str | dict,
plugin: str,
expected_hash: str,
):
request_payload_url = {
"data": {
"data": image_url,
"data_format": "url",
"image_format": "tiff",
"out_data_format": "b64_json",
},
"priority": 0,
"model": model_name,
"softmax": False,
}
ret = requests.post(
server.url_for("pooling"),
json=request_payload_url,
)
response = ret.json()
# verify the request response is in the correct format
assert (parsed_response := IOProcessorResponse(**response))
# verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data
assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])
# Compute the output image hash and compare it against the expected hash
image_hash = _compute_image_hash(plugin_data["data"])
assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
)
@pytest.mark.parametrize(
"model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
def test_prithvi_mae_plugin_offline(
vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
):
img_data = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
prompt = dict(data=img_data)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
enable_mm_embeds=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=32,
io_processor_plugin=plugin,
default_torch_num_threads=1,
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(prompt, pooling_task="plugin")
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(hasattr(output, attr) for attr in ["type", "format", "data"])
# Compute the output image hash and compare it against the expected hash
image_hash = _compute_image_hash(output.data)
assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment