Unverified Commit 51d5e9be authored by mgazz's avatar mgazz Committed by GitHub
Browse files

[Core][Model] Terratorch backend integration (#23513)


Signed-off-by: default avatarMichele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: default avatarChristian Pinto <christian.pinto@ibm.com>
Co-authored-by: default avatarChristian Pinto <christian.pinto@ibm.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent e7fc7001
......@@ -45,7 +45,11 @@ datamodule_config = {
class PrithviMAE:
def __init__(self, model):
self.model = LLM(
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
model=model,
skip_tokenizer_init=True,
dtype="float16",
enforce_eager=True,
model_impl="terratorch",
)
def run(self, input_data, location_coords):
......
......@@ -37,6 +37,7 @@ def main():
# The maximum number depends on the available GPU memory
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff_india",
model_impl="terratorch",
)
pooling_params = PoolingParams(task="encode", softmax=False)
......
......@@ -15,6 +15,7 @@ import requests
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india
......
......@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
decord==0.6.0
terratorch==1.1rc3 # required for PrithviMAE test
......@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
terratorch==1.1rc2
terratorch==1.1rc3
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
......
......@@ -298,6 +298,8 @@ def _compare_tp(
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
hf_config = get_config(model_id, trust_remote_code)
skip_tokenizer_init = model_info.skip_tokenizer_init
max_num_seqs = model_info.max_num_seqs
dtype = "float16"
if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
......@@ -351,6 +353,10 @@ def _compare_tp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if max_num_seqs:
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
testing_ray_compiled_graph = False
......
......@@ -178,6 +178,7 @@ def _compare_sp(
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
skip_tokenizer_init = model_info.skip_tokenizer_init
if load_format == "dummy":
# Avoid OOM
......@@ -227,6 +228,8 @@ def _compare_sp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
compilation_config = {
'level': 3,
......
......@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Initialize the tokenizer
tokenizer = get_tokenizer(
......
......@@ -11,7 +11,7 @@ import torch
from ...utils import RemoteOpenAIServer
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"
......@@ -35,7 +35,9 @@ def server():
"--trust-remote-code",
"--skip-tokenizer-init",
"--max-num-seqs",
"32"
"32",
"--model-impl",
"terratorch"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......
......@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
......@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
tokenizer_group = TokenizerGroup(
model,
......@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
tokenizer_group = TokenizerGroup(
model_config.tokenizer,
......
......@@ -69,6 +69,9 @@ def run_test(
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides:
vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
if model_info.skip_tokenizer_init:
vllm_runner_kwargs_[
"skip_tokenizer_init"] = model_info.skip_tokenizer_init
if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)
......
......@@ -46,7 +46,7 @@ def _run_test(
vllm_model.encode(prompt)
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@pytest.mark.core_model
......
......@@ -66,7 +66,9 @@ def _test_processing_correctness(
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
......
......@@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=hf_overrides_fn,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
......
......@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
original_weights = create_repo_dummy_weights(model_id)
......
......@@ -6,10 +6,11 @@ from dataclasses import dataclass, field
from typing import Any, Literal, Optional
import pytest
import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.config import TokenizerMode
from vllm.config import ModelDType, TokenizerMode
@dataclass(frozen=True)
......@@ -47,6 +48,23 @@ class _HfExamplesInfo:
The reason for the minimum/maximum version requirement.
"""
skip_tokenizer_init: bool = False
"""
If true, skip initialization of tokenizer and detokenizer.
"""
dtype: ModelDType = "auto"
"""
The data type for the model weights and activations.
"""
enforce_eager: bool = False
"""
Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
"""
is_available_online: bool = True
"""
Set this to ``False`` if the name of this architecture no longer exists on
......@@ -76,6 +94,9 @@ class _HfExamplesInfo:
If not specified, the default revision will be used.
"""
max_num_seqs: Optional[int] = None
"""Maximum number of sequences to be processed in a single iteration."""
def check_transformers_version(
self,
*,
......@@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
is_available_online=False), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
}
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
......
......@@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
enforce_eager=model_info.enforce_eager,
skip_tokenizer_init=model_info.skip_tokenizer_init,
dtype=model_info.dtype,
speculative_config={
"model": model_info.speculative_model,
"num_speculative_tokens": 1,
......@@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
model_impl=ModelImpl.TRANSFORMERS
if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
hf_overrides=hf_overrides_fn,
)
max_num_seqs=model_info.max_num_seqs)
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.conftest import VllmRunner
from vllm.utils import set_default_torch_num_threads
@pytest.mark.parametrize(
"model",
[
"mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"mgazz/Prithvi_v2_eo_300_tl_unet_agb"
],
)
def test_inference(
vllm_runner: type[VllmRunner],
model: str,
) -> None:
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
prompt = dict(prompt_token_ids=[1],
multi_modal_data=dict(pixel_values=pixel_values,
location_coords=location_coords))
with (
set_default_torch_num_threads(1),
vllm_runner(
model,
runner="pooling",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,
) as vllm_model,
):
vllm_output = vllm_model.llm.encode(prompt)
assert torch.equal(
torch.isnan(vllm_output[0].outputs.data).any(),
torch.tensor(False))
......@@ -294,6 +294,8 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_cache_gb=mm_processor_cache_gb,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
**model_config_kwargs,
)
return InputContext(model_config)
......
......@@ -7,12 +7,11 @@ import requests
from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.llm import LLM
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
......@@ -23,61 +22,7 @@ def test_loading_missing_plugin():
get_io_processor(vllm_config, "wrong_plugin")
def test_loading_engine_with_wrong_plugin():
with pytest.raises(ValueError):
LLM(
model=MODEL_NAME,
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=32,
io_processor_plugin="wrong_plugin",
)
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
pooling_params = PoolingParams(task="encode", softmax=False)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def server():
args = [
"--runner",
......@@ -90,7 +35,9 @@ def server():
"--max-num-seqs",
"32",
"--io-processor-plugin",
"prithvi_to_tiff_valencia"
"prithvi_to_tiff_valencia",
"--model-impl",
"terratorch",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......@@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online(
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(plugin_data["data"])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
pooling_params = PoolingParams(task="encode", softmax=False)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment