Unverified Commit 342a7cda authored by Christian Pinto's avatar Christian Pinto Committed by GitHub
Browse files

[Misc] Update tests and examples for Prithvi/Terratorch models (#34416)


Signed-off-by: default avatarChristian Pinto <christian.pinto@ibm.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent d1ea65d0
......@@ -28,7 +28,7 @@ def main():
)
llm = LLM(
model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
......
......@@ -391,7 +391,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model",
type=str,
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
default="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
help="Path to a checkpoint file to load from.",
)
parser.add_argument(
......
......@@ -14,9 +14,7 @@ import requests
# - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --trust-remote-code
# --model='ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11'
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin terratorch_segmentation
# --enable-mm-embeds
......@@ -34,7 +32,7 @@ def main():
"out_data_format": "b64_json",
},
"priority": 0,
"model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
"model": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
}
ret = requests.post(server_endpoint, json=request_payload_url)
......
......@@ -56,7 +56,14 @@ runai-model-streamer[s3,gcs]==0.15.3
fastsafetensors>=0.2.2 # 0.2.2 contains important fixes for multi-GPU mem usage
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0
terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test
terratorch >= 1.2.2 # Required for Prithvi tests
imagehash # Required for Prithvi tests
segmentation-models-pytorch > 0.4.0 # Required for Prithvi tests
gpt-oss >= 0.0.7; python_version > '3.11'
perceptron # required for isaac test
# Newer versions of datasets require torchcoded, that makes the tests fail in CI because of a missing library.
# Older versions are in conflict with teerratorch requirements.
datasets>=3.3.0,<=3.6.0
\ No newline at end of file
# This file was autogenerated by uv via the following command:
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 --python-version 3.12
absl-py==2.1.0
# via rouge-score
# via
# rouge-score
# tensorboard
accelerate==1.0.1
# via
# lm-eval
......@@ -31,9 +33,7 @@ albumentations==1.4.6
# -r requirements/test.in
# terratorch
alembic==1.16.4
# via
# mlflow
# optuna
# via optuna
annotated-doc==0.0.4
# via fastapi
annotated-types==0.7.0
......@@ -74,8 +74,6 @@ bitsandbytes==0.46.1
# lightning
black==24.10.0
# via datamodel-code-generator
blinker==1.9.0
# via flask
blobfile==3.0.0
# via -r requirements/test.in
bm25s==0.2.13
......@@ -93,9 +91,7 @@ bounded-pool-executor==0.0.3
buildkite-test-collector==0.1.9
# via -r requirements/test.in
cachetools==5.5.2
# via
# google-auth
# mlflow-skinny
# via google-auth
certifi==2024.8.30
# via
# fiona
......@@ -106,6 +102,7 @@ certifi==2024.8.30
# pyproj
# rasterio
# requests
# sentry-sdk
cffi==1.17.1
# via soundfile
chardet==5.2.0
......@@ -120,15 +117,14 @@ click==8.1.7
# click-plugins
# cligj
# fiona
# flask
# jiwer
# mlflow-skinny
# nltk
# rasterio
# ray
# schemathesis
# typer
# uvicorn
# wandb
click-plugins==1.1.1.2
# via
# fiona
......@@ -137,8 +133,6 @@ cligj==0.7.2
# via
# fiona
# rasterio
cloudpickle==3.1.1
# via mlflow-skinny
colorama==0.4.6
# via
# perceptron
......@@ -163,16 +157,15 @@ cupy-cuda12x==13.6.0
# via ray
cycler==0.12.1
# via matplotlib
databricks-sdk==0.59.0
# via mlflow-skinny
datamodel-code-generator==0.26.3
# via -r requirements/test.in
dataproperty==1.0.1
# via
# pytablewriter
# tabledata
datasets==3.0.2
datasets==3.3.0
# via
# -r requirements/test.in
# evaluate
# lm-eval
# mteb
......@@ -180,6 +173,8 @@ decorator==5.1.1
# via librosa
decord==0.6.0
# via -r requirements/test.in
diffusers==0.36.0
# via terratorch
dill==0.3.8
# via
# datasets
......@@ -191,15 +186,11 @@ distlib==0.3.9
dnspython==2.7.0
# via email-validator
docker==7.1.0
# via
# gpt-oss
# mlflow
# via gpt-oss
docopt==0.6.2
# via num2words
docstring-parser==0.17.0
# via jsonargparse
efficientnet-pytorch==0.7.1
# via segmentation-models-pytorch
einops==0.8.1
# via
# -r requirements/test.in
......@@ -217,9 +208,7 @@ encodec==0.1.1
evaluate==0.4.3
# via lm-eval
fastapi==0.128.0
# via
# gpt-oss
# mlflow-skinny
# via gpt-oss
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
......@@ -230,6 +219,7 @@ filelock==3.16.1
# via
# blobfile
# datasets
# diffusers
# huggingface-hub
# ray
# torch
......@@ -237,8 +227,6 @@ filelock==3.16.1
# virtualenv
fiona==1.10.1
# via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.55.0
# via matplotlib
fqdn==1.5.1
......@@ -249,7 +237,7 @@ frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.9.0
fsspec==2024.12.0
# via
# datasets
# evaluate
......@@ -257,6 +245,7 @@ fsspec==2024.9.0
# huggingface-hub
# lightning
# pytorch-lightning
# tacoreader
# torch
ftfy==6.3.1
# via open-clip-torch
......@@ -269,7 +258,7 @@ geopandas==1.0.1
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
# via mlflow-skinny
# via wandb
google-api-core==2.24.2
# via
# google-cloud-core
......@@ -277,7 +266,6 @@ google-api-core==2.24.2
# opencensus
google-auth==2.40.2
# via
# databricks-sdk
# google-api-core
# google-cloud-core
# google-cloud-storage
......@@ -296,25 +284,17 @@ googleapis-common-protos==1.70.0
# via google-api-core
gpt-oss==0.0.8
# via -r requirements/test.in
graphene==3.4.3
# via mlflow
graphql-core==3.2.6
# via
# graphene
# graphql-relay
# hypothesis-graphql
graphql-relay==3.2.0
# via graphene
# via hypothesis-graphql
greenlet==3.2.3
# via sqlalchemy
grpcio==1.78.0
# via
# grpcio-tools
# ray
# tensorboard
grpcio-tools==1.78.0
# via -r requirements/test.in
gunicorn==23.0.0
# via mlflow
h11==0.14.0
# via
# httpcore
......@@ -338,12 +318,14 @@ httpcore==1.0.6
httpx==0.27.2
# via
# -r requirements/test.in
# diffusers
# perceptron
# schemathesis
huggingface-hub==0.36.2
# via
# accelerate
# datasets
# diffusers
# evaluate
# open-clip-torch
# peft
......@@ -379,11 +361,13 @@ idna==3.10
# jsonschema
# requests
# yarl
imagehash==4.3.2
# via -r requirements/test.in
imageio==2.37.0
# via scikit-image
importlib-metadata==8.7.0
# via
# mlflow-skinny
# diffusers
# opentelemetry-api
importlib-resources==6.5.2
# via typeshed-client
......@@ -395,14 +379,10 @@ isoduration==20.11.0
# via jsonschema
isort==5.13.2
# via datamodel-code-generator
itsdangerous==2.2.0
# via flask
jinja2==3.1.6
# via
# datamodel-code-generator
# flask
# genai-perf
# mlflow
# torch
jiwer==3.0.5
# via -r requirements/test.in
......@@ -415,12 +395,14 @@ joblib==1.4.2
# librosa
# nltk
# scikit-learn
jsonargparse==4.35.0
jsonargparse==4.46.0
# via
# lightning
# terratorch
jsonlines==4.0.0
# via lm-eval
jsonnet==0.21.0
# via jsonargparse
jsonpointer==3.0.0
# via jsonschema
jsonschema==4.23.0
......@@ -449,13 +431,13 @@ libnacl==2.1.0
# via tensorizer
librosa==0.10.2.post1
# via -r requirements/test.in
lightly==1.5.20
lightly==1.5.22
# via
# terratorch
# torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.5.1.post0
lightning==2.6.1
# via
# terratorch
# torchgeo
......@@ -476,12 +458,11 @@ lxml==5.3.0
mako==1.3.10
# via alembic
markdown==3.8.2
# via mlflow
# via tensorboard
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.1
# via
# flask
# jinja2
# mako
# werkzeug
......@@ -489,7 +470,6 @@ matplotlib==3.9.2
# via
# -r requirements/test.in
# lightning
# mlflow
# pycocotools
# torchgeo
mbstrdecoder==1.1.3
......@@ -501,10 +481,6 @@ mdurl==0.1.2
# via markdown-it-py
mistral-common==1.9.1
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
mlflow-skinny==2.22.0
# via mlflow
more-itertools==10.5.0
# via lm-eval
mpmath==1.3.0
......@@ -523,8 +499,6 @@ multiprocess==0.70.16
# via
# datasets
# evaluate
munch==4.0.0
# via pretrainedmodels
mypy-extensions==1.0.0
# via black
networkx==3.2.1
......@@ -553,6 +527,7 @@ numpy==2.2.6
# cupy-cuda12x
# datasets
# decord
# diffusers
# einx
# encodec
# evaluate
......@@ -560,13 +535,13 @@ numpy==2.2.6
# genai-perf
# geopandas
# h5py
# imagehash
# imageio
# librosa
# lightly
# lightly-utils
# matplotlib
# mistral-common
# mlflow
# mteb
# numba
# numexpr
......@@ -578,6 +553,7 @@ numpy==2.2.6
# perceptron
# pycocotools
# pyogrio
# pywavelets
# rasterio
# rioxarray
# rouge-score
......@@ -590,8 +566,10 @@ numpy==2.2.6
# shapely
# soxr
# statsmodels
# tensorboard
# tensorboardx
# tensorizer
# terratorch
# tifffile
# torchgeo
# torchmetrics
......@@ -659,7 +637,6 @@ opencv-python-headless==4.13.0.90
# mistral-common
opentelemetry-api==1.35.0
# via
# mlflow-skinny
# opentelemetry-exporter-prometheus
# opentelemetry-sdk
# opentelemetry-semantic-conventions
......@@ -669,7 +646,6 @@ opentelemetry-proto==1.36.0
# via ray
opentelemetry-sdk==1.35.0
# via
# mlflow-skinny
# opentelemetry-exporter-prometheus
# ray
opentelemetry-semantic-conventions==0.56b0
......@@ -687,7 +663,6 @@ packaging==24.2
# evaluate
# fastparquet
# geopandas
# gunicorn
# huggingface-hub
# hydra-core
# kornia
......@@ -695,7 +670,6 @@ packaging==24.2
# lightning
# lightning-utilities
# matplotlib
# mlflow-skinny
# optuna
# peft
# plotly
......@@ -708,10 +682,12 @@ packaging==24.2
# rioxarray
# scikit-image
# statsmodels
# tensorboard
# tensorboardx
# torchmetrics
# transformers
# typepy
# wandb
# xarray
pandas==2.2.3
# via
......@@ -720,8 +696,8 @@ pandas==2.2.3
# fastparquet
# genai-perf
# geopandas
# mlflow
# statsmodels
# tacoreader
# torchgeo
# xarray
pathspec==0.12.1
......@@ -740,7 +716,9 @@ perf-analyzer==0.1.0
# via genai-perf
pillow==10.4.0
# via
# diffusers
# genai-perf
# imagehash
# imageio
# lightly-utils
# matplotlib
......@@ -748,6 +726,7 @@ pillow==10.4.0
# perceptron
# scikit-image
# segmentation-models-pytorch
# tensorboard
# torchgeo
# torchvision
platformdirs==4.3.6
......@@ -755,6 +734,7 @@ platformdirs==4.3.6
# black
# pooch
# virtualenv
# wandb
plotly==5.24.1
# via genai-perf
pluggy==1.5.0
......@@ -769,8 +749,6 @@ portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements/test.in
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0
# via
# opentelemetry-exporter-prometheus
......@@ -786,12 +764,13 @@ protobuf==6.33.2
# google-api-core
# googleapis-common-protos
# grpcio-tools
# mlflow-skinny
# opentelemetry-proto
# proto-plus
# ray
# tensorboard
# tensorboardx
# tensorizer
# wandb
psutil==6.1.0
# via
# accelerate
......@@ -801,11 +780,12 @@ py==1.11.0
# via pytest-forked
py-spy==0.4.0
# via ray
pyarrow==18.0.0
pyarrow==23.0.0
# via
# datasets
# genai-perf
# mlflow
# tacoreader
# terratorch
pyasn1==0.6.1
# via
# pyasn1-modules
......@@ -831,11 +811,11 @@ pydantic==2.12.0
# gpt-oss
# lightly
# mistral-common
# mlflow-skinny
# mteb
# openai-harmony
# pydantic-extra-types
# ray
# wandb
pydantic-core==2.41.1
# via pydantic
pydantic-extra-types==2.10.5
......@@ -873,7 +853,6 @@ pytest==8.3.5
# pytest-subtests
# pytest-timeout
# schemathesis
# terratorch
pytest-asyncio==0.24.0
# via -r requirements/test.in
pytest-cov==6.3.0
......@@ -896,7 +875,6 @@ python-dateutil==2.9.0.post0
# via
# arrow
# botocore
# graphene
# lightly
# matplotlib
# pandas
......@@ -913,6 +891,8 @@ pytz==2024.2
# via
# pandas
# typepy
pywavelets==1.9.0
# via imagehash
pyyaml==6.0.2
# via
# accelerate
......@@ -923,7 +903,6 @@ pyyaml==6.0.2
# huggingface-hub
# jsonargparse
# lightning
# mlflow-skinny
# omegaconf
# optuna
# peft
......@@ -934,6 +913,7 @@ pyyaml==6.0.2
# timm
# transformers
# vocos
# wandb
rapidfuzz==3.12.1
# via jiwer
rasterio==1.4.3
......@@ -951,6 +931,7 @@ referencing==0.35.1
# jsonschema-specifications
regex==2024.9.11
# via
# diffusers
# nltk
# open-clip-torch
# sacrebleu
......@@ -959,8 +940,8 @@ regex==2024.9.11
requests==2.32.3
# via
# buildkite-test-collector
# databricks-sdk
# datasets
# diffusers
# docker
# evaluate
# google-api-core
......@@ -970,15 +951,16 @@ requests==2.32.3
# lightly
# lm-eval
# mistral-common
# mlflow-skinny
# mteb
# pooch
# ray
# responses
# schemathesis
# starlette-testclient
# tacoreader
# tiktoken
# transformers
# wandb
responses==0.25.3
# via genai-perf
rfc3339-validator==0.1.4
......@@ -991,6 +973,7 @@ rich==13.9.4
# lightning
# mteb
# perceptron
# terratorch
# typer
rioxarray==0.19.0
# via terratorch
......@@ -1017,47 +1000,55 @@ sacrebleu==2.4.3
safetensors==0.4.5
# via
# accelerate
# diffusers
# open-clip-torch
# peft
# segmentation-models-pytorch
# timm
# transformers
schemathesis==3.39.15
# via -r requirements/test.in
scikit-image==0.25.2
# via albumentations
# via
# albumentations
# terratorch
scikit-learn==1.5.2
# via
# albumentations
# librosa
# lm-eval
# mlflow
# mteb
# sentence-transformers
# terratorch
scipy==1.13.1
# via
# albumentations
# bm25s
# imagehash
# librosa
# mlflow
# mteb
# scikit-image
# scikit-learn
# sentence-transformers
# statsmodels
# vocos
segmentation-models-pytorch==0.4.0
segmentation-models-pytorch==0.5.0
# via
# -r requirements/test.in
# terratorch
# torchgeo
sentence-transformers==5.2.0
# via
# -r requirements/test.in
# mteb
sentry-sdk==2.52.0
# via wandb
setuptools==77.0.3
# via
# grpcio-tools
# lightning-utilities
# pytablewriter
# tensorboard
# torch
shapely==2.1.1
# via
......@@ -1075,7 +1066,6 @@ six==1.16.0
# python-dateutil
# rfc3339-validator
# rouge-score
# segmentation-models-pytorch
smart-open==7.1.0
# via ray
smmap==5.0.2
......@@ -1099,12 +1089,9 @@ soxr==0.5.0.post1
sqlalchemy==2.0.41
# via
# alembic
# mlflow
# optuna
sqlitedict==2.1.0
# via lm-eval
sqlparse==0.5.3
# via mlflow-skinny
starlette==0.50.0
# via
# fastapi
......@@ -1124,6 +1111,8 @@ tabledata==1.3.3
# via pytablewriter
tabulate==0.9.0
# via sacrebleu
tacoreader==0.5.6
# via terratorch
tblib==3.1.0
# via -r requirements/test.in
tcolorpy==0.1.6
......@@ -1133,13 +1122,19 @@ tenacity==9.1.2
# gpt-oss
# lm-eval
# plotly
tensorboard==2.20.0
# via terratorch
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
termcolor==3.1.0
# via gpt-oss
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
# via
# gpt-oss
# terratorch
terratorch==1.2.2
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
......@@ -1172,7 +1167,6 @@ torch==2.10.0+cu129
# -r requirements/test.in
# accelerate
# bitsandbytes
# efficientnet-pytorch
# encodec
# kornia
# lightly
......@@ -1181,7 +1175,6 @@ torch==2.10.0+cu129
# mteb
# open-clip-torch
# peft
# pretrainedmodels
# pytorch-lightning
# runai-model-streamer
# segmentation-models-pytorch
......@@ -1213,12 +1206,11 @@ torchvision==0.25.0+cu129
# -r requirements/test.in
# lightly
# open-clip-torch
# pretrainedmodels
# segmentation-models-pytorch
# terratorch
# timm
# torchgeo
tqdm==4.66.6
tqdm==4.67.3
# via
# datasets
# evaluate
......@@ -1232,10 +1224,11 @@ tqdm==4.66.6
# optuna
# peft
# pqdm
# pretrainedmodels
# pytorch-lightning
# segmentation-models-pytorch
# sentence-transformers
# tacoreader
# terratorch
# tqdm-multiprocess
# transformers
tqdm-multiprocess==0.0.11
......@@ -1274,14 +1267,12 @@ typing-extensions==4.15.0
# alembic
# chz
# fastapi
# graphene
# grpcio
# huggingface-hub
# librosa
# lightning
# lightning-utilities
# mistral-common
# mlflow-skinny
# mteb
# opentelemetry-api
# opentelemetry-sdk
......@@ -1299,6 +1290,7 @@ typing-extensions==4.15.0
# typer
# typeshed-client
# typing-inspection
# wandb
typing-inspection==0.4.2
# via pydantic
tzdata==2024.2
......@@ -1313,25 +1305,26 @@ urllib3==2.2.3
# lightly
# requests
# responses
# sentry-sdk
# tritonclient
uvicorn==0.35.0
# via
# gpt-oss
# mlflow-skinny
# via gpt-oss
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
# via ray
vocos==0.1.0
# via -r requirements/test.in
wandb==0.24.2
# via terratorch
wcwidth==0.2.13
# via ftfy
webcolors==24.11.1
# via jsonschema
werkzeug==3.1.3
# via
# flask
# schemathesis
# tensorboard
word2number==1.1
# via lm-eval
wrapt==1.17.2
......
......@@ -40,7 +40,7 @@ def _run_test(
vllm_model.llm.encode(prompt, pooling_task="plugin")
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
MODELS = ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@pytest.mark.core_model
......
......@@ -13,7 +13,7 @@ from tests.utils import create_new_process_for_each_test
"model",
[
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"mgazz/Prithvi_v2_eo_300_tl_unet_agb",
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
],
)
def test_inference(
......
......@@ -44,12 +44,8 @@ datamodule_config: DataModuleConfig = {
"no_label_replace": -1,
"num_workers": 8,
"test_transform": [
albumentations.Resize(
always_apply=False, height=448, interpolation=1, p=1, width=448
),
albumentations.pytorch.ToTensorV2(
transpose_mask=False, always_apply=True, p=1.0
),
albumentations.Resize(height=448, interpolation=1, p=1, width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False, p=1.0),
],
}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import imagehash
import pytest
import requests
from PIL import Image
from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
models_config = {
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
"image_url": "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff", # noqa: E501
"out_hash": "aa6d92ad25926a5e",
"plugin": "prithvi_to_tiff",
},
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars": {
"image_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars/resolve/main/examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif", # noqa: E501
"out_hash": "c07f4f602da73552",
"plugin": "prithvi_to_tiff",
},
}
def _compute_image_hash(base64_data: str) -> str:
# Decode the base64 output and create image from byte stream
decoded_image = base64.b64decode(base64_data)
image = Image.open(io.BytesIO(decoded_image))
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
# Compute perceptual hash of the output image
return str(imagehash.phash(image))
def test_loading_missing_plugin():
......@@ -22,33 +43,39 @@ def test_loading_missing_plugin():
@pytest.fixture(scope="function")
def server():
def server(model_name, plugin):
args = [
"--runner",
"pooling",
"--enforce-eager",
"--trust-remote-code",
"--skip-tokenizer-init",
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
"--max-num-seqs",
"32",
"--io-processor-plugin",
"prithvi_to_tiff",
"--model-impl",
"terratorch",
plugin,
"--enable-mm-embeds",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
with RemoteOpenAIServer(model_name, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
"model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
async def test_prithvi_mae_plugin_online(
server: RemoteOpenAIServer,
model_name: str,
image_url: str | dict,
plugin: str,
expected_hash: str,
):
request_payload_url = {
"data": {
......@@ -74,16 +101,25 @@ async def test_prithvi_mae_plugin_online(
# verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data
assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(plugin_data["data"])
# Compute the output image hash and compare it against the expected hash
image_hash = _compute_image_hash(plugin_data["data"])
assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
)
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
@pytest.mark.parametrize(
"model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
def test_prithvi_mae_plugin_offline(
vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
):
img_prompt = dict(
data=image_url,
data_format="url",
......@@ -96,13 +132,12 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
runner="pooling",
skip_tokenizer_init=True,
enable_mm_embeds=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff",
max_num_seqs=32,
io_processor_plugin=plugin,
default_torch_num_threads=1,
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
output = pooler_output[0].outputs
......@@ -110,6 +145,8 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
# verify the output is formatted as expected for this plugin
assert all(hasattr(output, attr) for attr in ["type", "format", "data"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)
# Compute the output image hash and compare it against the expected hash
image_hash = _compute_image_hash(output.data)
assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment