Unverified commit 51d5e9be, authored by mgazz and committed by GitHub
Browse files

[Core][Model] Terratorch backend integration (#23513)


Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent e7fc7001
...@@ -45,7 +45,11 @@ datamodule_config = { ...@@ -45,7 +45,11 @@ datamodule_config = {
class PrithviMAE: class PrithviMAE:
def __init__(self, model): def __init__(self, model):
self.model = LLM( self.model = LLM(
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True model=model,
skip_tokenizer_init=True,
dtype="float16",
enforce_eager=True,
model_impl="terratorch",
) )
def run(self, input_data, location_coords): def run(self, input_data, location_coords):
......
...@@ -37,6 +37,7 @@ def main(): ...@@ -37,6 +37,7 @@ def main():
# The maximum number depends on the available GPU memory # The maximum number depends on the available GPU memory
max_num_seqs=32, max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff_india", io_processor_plugin="prithvi_to_tiff_india",
model_impl="terratorch",
) )
pooling_params = PoolingParams(task="encode", softmax=False) pooling_params = PoolingParams(task="encode", softmax=False)
......
...@@ -15,6 +15,7 @@ import requests ...@@ -15,6 +15,7 @@ import requests
# https://github.com/christian-pinto/prithvi_io_processor_plugin # https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args # - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code # --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager # --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india # --io-processor-plugin prithvi_to_tiff_india
......
...@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0 ...@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0 runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10 fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10 pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
decord==0.6.0 decord==0.6.0
terratorch==1.1rc3 # required for PrithviMAE test
...@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4 ...@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
# via lightning # via lightning
tensorizer==2.10.1 tensorizer==2.10.1
# via -r requirements/test.in # via -r requirements/test.in
terratorch==1.1rc2 terratorch==1.1rc3
# via -r requirements/test.in # via -r requirements/test.in
threadpoolctl==3.5.0 threadpoolctl==3.5.0
# via scikit-learn # via scikit-learn
......
...@@ -298,6 +298,8 @@ def _compare_tp( ...@@ -298,6 +298,8 @@ def _compare_tp(
tokenizer_mode = model_info.tokenizer_mode tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides hf_overrides = model_info.hf_overrides
hf_config = get_config(model_id, trust_remote_code) hf_config = get_config(model_id, trust_remote_code)
skip_tokenizer_init = model_info.skip_tokenizer_init
max_num_seqs = model_info.max_num_seqs
dtype = "float16" dtype = "float16"
if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS: if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
...@@ -351,6 +353,10 @@ def _compare_tp( ...@@ -351,6 +353,10 @@ def _compare_tp(
common_args.extend(["--load-format", load_format]) common_args.extend(["--load-format", load_format])
if hf_overrides: if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if max_num_seqs:
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
testing_ray_compiled_graph = False testing_ray_compiled_graph = False
......
...@@ -178,6 +178,7 @@ def _compare_sp( ...@@ -178,6 +178,7 @@ def _compare_sp(
trust_remote_code = model_info.trust_remote_code trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides hf_overrides = model_info.hf_overrides
skip_tokenizer_init = model_info.skip_tokenizer_init
if load_format == "dummy": if load_format == "dummy":
# Avoid OOM # Avoid OOM
...@@ -227,6 +228,8 @@ def _compare_sp( ...@@ -227,6 +228,8 @@ def _compare_sp(
common_args.extend(["--load-format", load_format]) common_args.extend(["--load-format", load_format])
if hf_overrides: if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
compilation_config = { compilation_config = {
'level': 3, 'level': 3,
......
...@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt, ...@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision, revision=model_info.revision,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
) skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Initialize the tokenizer # Initialize the tokenizer
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM" MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16" DTYPE = "float16"
...@@ -35,7 +35,9 @@ def server(): ...@@ -35,7 +35,9 @@ def server():
"--trust-remote-code", "--trust-remote-code",
"--skip-tokenizer-init", "--skip-tokenizer-init",
"--max-num-seqs", "--max-num-seqs",
"32" "32",
"--model-impl",
"terratorch"
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......
...@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
revision=model_info.revision, revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
) skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer group and grab the underlying tokenizer # Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
...@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
revision=model_info.revision, revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
) skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
model, model,
...@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format): ...@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
revision=model_info.revision, revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
) skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
model_config.tokenizer, model_config.tokenizer,
......
...@@ -69,6 +69,9 @@ def run_test( ...@@ -69,6 +69,9 @@ def run_test(
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides: if model_info.hf_overrides:
vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
if model_info.skip_tokenizer_init:
vllm_runner_kwargs_[
"skip_tokenizer_init"] = model_info.skip_tokenizer_init
if vllm_runner_kwargs: if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs) vllm_runner_kwargs_.update(vllm_runner_kwargs)
......
...@@ -46,7 +46,7 @@ def _run_test( ...@@ -46,7 +46,7 @@ def _run_test(
vllm_model.encode(prompt) vllm_model.encode(prompt)
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@pytest.mark.core_model @pytest.mark.core_model
......
...@@ -66,7 +66,9 @@ def _test_processing_correctness( ...@@ -66,7 +66,9 @@ def _test_processing_correctness(
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data # Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048, mm_processor_cache_gb=2048,
) skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
......
...@@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str): ...@@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
revision=model_info.revision, revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
hf_overrides=hf_overrides_fn, hf_overrides=hf_overrides_fn,
) skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
......
...@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str): ...@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
revision=model_info.revision, revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
) skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
original_weights = create_repo_dummy_weights(model_id) original_weights = create_repo_dummy_weights(model_id)
......
...@@ -6,10 +6,11 @@ from dataclasses import dataclass, field ...@@ -6,10 +6,11 @@ from dataclasses import dataclass, field
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
import pytest import pytest
import torch
from packaging.version import Version from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.config import TokenizerMode from vllm.config import ModelDType, TokenizerMode
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -47,6 +48,23 @@ class _HfExamplesInfo: ...@@ -47,6 +48,23 @@ class _HfExamplesInfo:
The reason for the minimum/maximum version requirement. The reason for the minimum/maximum version requirement.
""" """
skip_tokenizer_init: bool = False
"""
If true, skip initialization of tokenizer and detokenizer.
"""
dtype: ModelDType = "auto"
"""
The data type for the model weights and activations.
"""
enforce_eager: bool = False
"""
Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
"""
is_available_online: bool = True is_available_online: bool = True
""" """
Set this to ``False`` if the name of this architecture no longer exists on Set this to ``False`` if the name of this architecture no longer exists on
...@@ -76,6 +94,9 @@ class _HfExamplesInfo: ...@@ -76,6 +94,9 @@ class _HfExamplesInfo:
If not specified, the default revision will be used. If not specified, the default revision will be used.
""" """
max_num_seqs: Optional[int] = None
"""Maximum number of sequences to be processed in a single iteration."""
def check_transformers_version( def check_transformers_version(
self, self,
*, *,
...@@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True), trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 "PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
is_available_online=False), # noqa: E501 dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
} }
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
......
...@@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, ...@@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
tokenizer=model_info.tokenizer, tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode, tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision, revision=model_info.revision,
enforce_eager=model_info.enforce_eager,
skip_tokenizer_init=model_info.skip_tokenizer_init,
dtype=model_info.dtype,
speculative_config={ speculative_config={
"model": model_info.speculative_model, "model": model_info.speculative_model,
"num_speculative_tokens": 1, "num_speculative_tokens": 1,
...@@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, ...@@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
model_impl=ModelImpl.TRANSFORMERS model_impl=ModelImpl.TRANSFORMERS
if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM, if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
hf_overrides=hf_overrides_fn, hf_overrides=hf_overrides_fn,
) max_num_seqs=model_info.max_num_seqs)
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.conftest import VllmRunner
from vllm.utils import set_default_torch_num_threads
@pytest.mark.parametrize(
    "model",
    [
        "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        "mgazz/Prithvi_v2_eo_300_tl_unet_agb",
    ],
)
def test_inference(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:
    """Encode one synthetic prompt and check the output contains no NaNs.

    Args:
        vllm_runner: Runner fixture used as a context manager to start the
            engine for ``model``.
        model: Model identifier of the checkpoint under test.
    """
    # Constant-valued float16 inputs: the test only checks that inference
    # produces NaN-free output, not that the prediction is meaningful.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)

    # The tokenizer is skipped below, so the prompt carries an explicit token
    # id (presumably a placeholder - confirm against the model's input
    # processor) alongside the multi-modal tensors.
    prompt = dict(
        prompt_token_ids=[1],
        multi_modal_data=dict(
            pixel_values=pixel_values,
            location_coords=location_coords,
        ),
    )

    with (
            set_default_torch_num_threads(1),
            vllm_runner(
                model,
                runner="pooling",
                dtype=torch.float16,
                enforce_eager=True,
                skip_tokenizer_init=True,
                # Limit the maximum number of sequences to avoid the
                # test going OOM during the warmup run
                max_num_seqs=32,
            ) as vllm_model,
    ):
        vllm_output = vllm_model.llm.encode(prompt)
        output_data = vllm_output[0].outputs.data

        # Direct truthiness check: works regardless of the device the output
        # lives on and reports which model failed.
        assert not torch.isnan(output_data).any(), (
            f"NaN values found in the pooled output of {model}")
...@@ -294,6 +294,8 @@ def build_model_context( ...@@ -294,6 +294,8 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt, limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_cache_gb=mm_processor_cache_gb, mm_processor_cache_gb=mm_processor_cache_gb,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
**model_config_kwargs, **model_config_kwargs,
) )
return InputContext(model_config) return InputContext(model_config)
......
...@@ -7,12 +7,11 @@ import requests ...@@ -7,12 +7,11 @@ import requests
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.llm import LLM
from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM" MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
...@@ -23,61 +22,7 @@ def test_loading_missing_plugin(): ...@@ -23,61 +22,7 @@ def test_loading_missing_plugin():
get_io_processor(vllm_config, "wrong_plugin") get_io_processor(vllm_config, "wrong_plugin")
def test_loading_engine_with_wrong_plugin(): @pytest.fixture(scope="function")
with pytest.raises(ValueError):
LLM(
model=MODEL_NAME,
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=32,
io_processor_plugin="wrong_plugin",
)
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
pooling_params = PoolingParams(task="encode", softmax=False)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)
@pytest.fixture(scope="module")
def server(): def server():
args = [ args = [
"--runner", "--runner",
...@@ -90,7 +35,9 @@ def server(): ...@@ -90,7 +35,9 @@ def server():
"--max-num-seqs", "--max-num-seqs",
"32", "32",
"--io-processor-plugin", "--io-processor-plugin",
"prithvi_to_tiff_valencia" "prithvi_to_tiff_valencia",
"--model-impl",
"terratorch",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
...@@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online( ...@@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online(
# We just check that the output is a valid base64 string. # We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted. # Raises an exception and fails the test if the string is corrupted.
base64.b64decode(plugin_data["data"]) base64.b64decode(plugin_data["data"])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
    """Run the Valencia TIFF IO-processor plugin offline and check its output.

    The request points the plugin at ``image_url`` and asks for a
    ``b64_json`` result. The test then checks that the returned object
    exposes the expected attributes and that its ``data`` field decodes as
    base64.
    """
    request = {
        "data": image_url,
        "data_format": "url",
        "image_format": "tiff",
        "out_data_format": "b64_json",
    }
    encode_params = PoolingParams(task="encode", softmax=False)

    engine_options = {
        "runner": "pooling",
        "skip_tokenizer_init": True,
        "trust_remote_code": True,
        "enforce_eager": True,
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM in CI.
        "max_num_seqs": 1,
        "model_impl": "terratorch",
        "io_processor_plugin": "prithvi_to_tiff_valencia",
    }

    with vllm_runner(model_name, **engine_options) as llm_runner:
        llm = llm_runner.get_llm()
        results = llm.encode(request, pooling_params=encode_params)
        plugin_output = results[0].outputs

        # The plugin response must expose every field of its output format.
        required_attrs = ["type", "format", "data", "request_id"]
        assert all(hasattr(plugin_output, name) for name in required_attrs)

        # Decoding raises (and so fails the test) if the payload is not a
        # valid base64 string; the decoded bytes themselves are not inspected.
        base64.b64decode(plugin_output.data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment