Unverified Commit 342a7cda authored by Christian Pinto's avatar Christian Pinto Committed by GitHub
Browse files

[Misc] Update tests and examples for Prithvi/Terratorch models (#34416)


Signed-off-by: default avatarChristian Pinto <christian.pinto@ibm.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent d1ea65d0
...@@ -28,7 +28,7 @@ def main(): ...@@ -28,7 +28,7 @@ def main():
) )
llm = LLM( llm = LLM(
model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
skip_tokenizer_init=True, skip_tokenizer_init=True,
trust_remote_code=True, trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
......
...@@ -391,7 +391,7 @@ if __name__ == "__main__": ...@@ -391,7 +391,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", default="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
help="Path to a checkpoint file to load from.", help="Path to a checkpoint file to load from.",
) )
parser.add_argument( parser.add_argument(
......
...@@ -14,9 +14,7 @@ import requests ...@@ -14,9 +14,7 @@ import requests
# - install TerraTorch v1.1 (or later): # - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1 # pip install terratorch>=v1.1
# - start vllm in serving mode with the below args # - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' # --model='ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11'
# --model-impl terratorch
# --trust-remote-code
# --skip-tokenizer-init --enforce-eager # --skip-tokenizer-init --enforce-eager
# --io-processor-plugin terratorch_segmentation # --io-processor-plugin terratorch_segmentation
# --enable-mm-embeds # --enable-mm-embeds
...@@ -34,7 +32,7 @@ def main(): ...@@ -34,7 +32,7 @@ def main():
"out_data_format": "b64_json", "out_data_format": "b64_json",
}, },
"priority": 0, "priority": 0,
"model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", "model": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
} }
ret = requests.post(server_endpoint, json=request_payload_url) ret = requests.post(server_endpoint, json=request_payload_url)
......
...@@ -56,7 +56,14 @@ runai-model-streamer[s3,gcs]==0.15.3 ...@@ -56,7 +56,14 @@ runai-model-streamer[s3,gcs]==0.15.3
fastsafetensors>=0.2.2 # 0.2.2 contains important fixes for multi-GPU mem usage fastsafetensors>=0.2.2 # 0.2.2 contains important fixes for multi-GPU mem usage
pydantic>=2.12 # 2.11 leads to error on python 3.13 pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0 decord==0.6.0
terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test terratorch >= 1.2.2 # Required for Prithvi tests
imagehash # Required for Prithvi tests
segmentation-models-pytorch > 0.4.0 # Required for Prithvi tests
gpt-oss >= 0.0.7; python_version > '3.11' gpt-oss >= 0.0.7; python_version > '3.11'
perceptron # required for isaac test perceptron # required for isaac test
# Newer versions of datasets require torchcoded, that makes the tests fail in CI because of a missing library.
# Older versions are in conflict with teerratorch requirements.
datasets>=3.3.0,<=3.6.0
\ No newline at end of file
# This file was autogenerated by uv via the following command: # This file was autogenerated by uv via the following command:
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 --python-version 3.12 # uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 --python-version 3.12
absl-py==2.1.0 absl-py==2.1.0
# via rouge-score # via
# rouge-score
# tensorboard
accelerate==1.0.1 accelerate==1.0.1
# via # via
# lm-eval # lm-eval
...@@ -31,9 +33,7 @@ albumentations==1.4.6 ...@@ -31,9 +33,7 @@ albumentations==1.4.6
# -r requirements/test.in # -r requirements/test.in
# terratorch # terratorch
alembic==1.16.4 alembic==1.16.4
# via # via optuna
# mlflow
# optuna
annotated-doc==0.0.4 annotated-doc==0.0.4
# via fastapi # via fastapi
annotated-types==0.7.0 annotated-types==0.7.0
...@@ -74,8 +74,6 @@ bitsandbytes==0.46.1 ...@@ -74,8 +74,6 @@ bitsandbytes==0.46.1
# lightning # lightning
black==24.10.0 black==24.10.0
# via datamodel-code-generator # via datamodel-code-generator
blinker==1.9.0
# via flask
blobfile==3.0.0 blobfile==3.0.0
# via -r requirements/test.in # via -r requirements/test.in
bm25s==0.2.13 bm25s==0.2.13
...@@ -93,9 +91,7 @@ bounded-pool-executor==0.0.3 ...@@ -93,9 +91,7 @@ bounded-pool-executor==0.0.3
buildkite-test-collector==0.1.9 buildkite-test-collector==0.1.9
# via -r requirements/test.in # via -r requirements/test.in
cachetools==5.5.2 cachetools==5.5.2
# via # via google-auth
# google-auth
# mlflow-skinny
certifi==2024.8.30 certifi==2024.8.30
# via # via
# fiona # fiona
...@@ -106,6 +102,7 @@ certifi==2024.8.30 ...@@ -106,6 +102,7 @@ certifi==2024.8.30
# pyproj # pyproj
# rasterio # rasterio
# requests # requests
# sentry-sdk
cffi==1.17.1 cffi==1.17.1
# via soundfile # via soundfile
chardet==5.2.0 chardet==5.2.0
...@@ -120,15 +117,14 @@ click==8.1.7 ...@@ -120,15 +117,14 @@ click==8.1.7
# click-plugins # click-plugins
# cligj # cligj
# fiona # fiona
# flask
# jiwer # jiwer
# mlflow-skinny
# nltk # nltk
# rasterio # rasterio
# ray # ray
# schemathesis # schemathesis
# typer # typer
# uvicorn # uvicorn
# wandb
click-plugins==1.1.1.2 click-plugins==1.1.1.2
# via # via
# fiona # fiona
...@@ -137,8 +133,6 @@ cligj==0.7.2 ...@@ -137,8 +133,6 @@ cligj==0.7.2
# via # via
# fiona # fiona
# rasterio # rasterio
cloudpickle==3.1.1
# via mlflow-skinny
colorama==0.4.6 colorama==0.4.6
# via # via
# perceptron # perceptron
...@@ -163,16 +157,15 @@ cupy-cuda12x==13.6.0 ...@@ -163,16 +157,15 @@ cupy-cuda12x==13.6.0
# via ray # via ray
cycler==0.12.1 cycler==0.12.1
# via matplotlib # via matplotlib
databricks-sdk==0.59.0
# via mlflow-skinny
datamodel-code-generator==0.26.3 datamodel-code-generator==0.26.3
# via -r requirements/test.in # via -r requirements/test.in
dataproperty==1.0.1 dataproperty==1.0.1
# via # via
# pytablewriter # pytablewriter
# tabledata # tabledata
datasets==3.0.2 datasets==3.3.0
# via # via
# -r requirements/test.in
# evaluate # evaluate
# lm-eval # lm-eval
# mteb # mteb
...@@ -180,6 +173,8 @@ decorator==5.1.1 ...@@ -180,6 +173,8 @@ decorator==5.1.1
# via librosa # via librosa
decord==0.6.0 decord==0.6.0
# via -r requirements/test.in # via -r requirements/test.in
diffusers==0.36.0
# via terratorch
dill==0.3.8 dill==0.3.8
# via # via
# datasets # datasets
...@@ -191,15 +186,11 @@ distlib==0.3.9 ...@@ -191,15 +186,11 @@ distlib==0.3.9
dnspython==2.7.0 dnspython==2.7.0
# via email-validator # via email-validator
docker==7.1.0 docker==7.1.0
# via # via gpt-oss
# gpt-oss
# mlflow
docopt==0.6.2 docopt==0.6.2
# via num2words # via num2words
docstring-parser==0.17.0 docstring-parser==0.17.0
# via jsonargparse # via jsonargparse
efficientnet-pytorch==0.7.1
# via segmentation-models-pytorch
einops==0.8.1 einops==0.8.1
# via # via
# -r requirements/test.in # -r requirements/test.in
...@@ -217,9 +208,7 @@ encodec==0.1.1 ...@@ -217,9 +208,7 @@ encodec==0.1.1
evaluate==0.4.3 evaluate==0.4.3
# via lm-eval # via lm-eval
fastapi==0.128.0 fastapi==0.128.0
# via # via gpt-oss
# gpt-oss
# mlflow-skinny
fastparquet==2024.11.0 fastparquet==2024.11.0
# via genai-perf # via genai-perf
fastrlock==0.8.2 fastrlock==0.8.2
...@@ -230,6 +219,7 @@ filelock==3.16.1 ...@@ -230,6 +219,7 @@ filelock==3.16.1
# via # via
# blobfile # blobfile
# datasets # datasets
# diffusers
# huggingface-hub # huggingface-hub
# ray # ray
# torch # torch
...@@ -237,8 +227,6 @@ filelock==3.16.1 ...@@ -237,8 +227,6 @@ filelock==3.16.1
# virtualenv # virtualenv
fiona==1.10.1 fiona==1.10.1
# via torchgeo # via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.55.0 fonttools==4.55.0
# via matplotlib # via matplotlib
fqdn==1.5.1 fqdn==1.5.1
...@@ -249,7 +237,7 @@ frozenlist==1.5.0 ...@@ -249,7 +237,7 @@ frozenlist==1.5.0
# via # via
# aiohttp # aiohttp
# aiosignal # aiosignal
fsspec==2024.9.0 fsspec==2024.12.0
# via # via
# datasets # datasets
# evaluate # evaluate
...@@ -257,6 +245,7 @@ fsspec==2024.9.0 ...@@ -257,6 +245,7 @@ fsspec==2024.9.0
# huggingface-hub # huggingface-hub
# lightning # lightning
# pytorch-lightning # pytorch-lightning
# tacoreader
# torch # torch
ftfy==6.3.1 ftfy==6.3.1
# via open-clip-torch # via open-clip-torch
...@@ -269,7 +258,7 @@ geopandas==1.0.1 ...@@ -269,7 +258,7 @@ geopandas==1.0.1
gitdb==4.0.12 gitdb==4.0.12
# via gitpython # via gitpython
gitpython==3.1.44 gitpython==3.1.44
# via mlflow-skinny # via wandb
google-api-core==2.24.2 google-api-core==2.24.2
# via # via
# google-cloud-core # google-cloud-core
...@@ -277,7 +266,6 @@ google-api-core==2.24.2 ...@@ -277,7 +266,6 @@ google-api-core==2.24.2
# opencensus # opencensus
google-auth==2.40.2 google-auth==2.40.2
# via # via
# databricks-sdk
# google-api-core # google-api-core
# google-cloud-core # google-cloud-core
# google-cloud-storage # google-cloud-storage
...@@ -296,25 +284,17 @@ googleapis-common-protos==1.70.0 ...@@ -296,25 +284,17 @@ googleapis-common-protos==1.70.0
# via google-api-core # via google-api-core
gpt-oss==0.0.8 gpt-oss==0.0.8
# via -r requirements/test.in # via -r requirements/test.in
graphene==3.4.3
# via mlflow
graphql-core==3.2.6 graphql-core==3.2.6
# via # via hypothesis-graphql
# graphene
# graphql-relay
# hypothesis-graphql
graphql-relay==3.2.0
# via graphene
greenlet==3.2.3 greenlet==3.2.3
# via sqlalchemy # via sqlalchemy
grpcio==1.78.0 grpcio==1.78.0
# via # via
# grpcio-tools # grpcio-tools
# ray # ray
# tensorboard
grpcio-tools==1.78.0 grpcio-tools==1.78.0
# via -r requirements/test.in # via -r requirements/test.in
gunicorn==23.0.0
# via mlflow
h11==0.14.0 h11==0.14.0
# via # via
# httpcore # httpcore
...@@ -338,12 +318,14 @@ httpcore==1.0.6 ...@@ -338,12 +318,14 @@ httpcore==1.0.6
httpx==0.27.2 httpx==0.27.2
# via # via
# -r requirements/test.in # -r requirements/test.in
# diffusers
# perceptron # perceptron
# schemathesis # schemathesis
huggingface-hub==0.36.2 huggingface-hub==0.36.2
# via # via
# accelerate # accelerate
# datasets # datasets
# diffusers
# evaluate # evaluate
# open-clip-torch # open-clip-torch
# peft # peft
...@@ -379,11 +361,13 @@ idna==3.10 ...@@ -379,11 +361,13 @@ idna==3.10
# jsonschema # jsonschema
# requests # requests
# yarl # yarl
imagehash==4.3.2
# via -r requirements/test.in
imageio==2.37.0 imageio==2.37.0
# via scikit-image # via scikit-image
importlib-metadata==8.7.0 importlib-metadata==8.7.0
# via # via
# mlflow-skinny # diffusers
# opentelemetry-api # opentelemetry-api
importlib-resources==6.5.2 importlib-resources==6.5.2
# via typeshed-client # via typeshed-client
...@@ -395,14 +379,10 @@ isoduration==20.11.0 ...@@ -395,14 +379,10 @@ isoduration==20.11.0
# via jsonschema # via jsonschema
isort==5.13.2 isort==5.13.2
# via datamodel-code-generator # via datamodel-code-generator
itsdangerous==2.2.0
# via flask
jinja2==3.1.6 jinja2==3.1.6
# via # via
# datamodel-code-generator # datamodel-code-generator
# flask
# genai-perf # genai-perf
# mlflow
# torch # torch
jiwer==3.0.5 jiwer==3.0.5
# via -r requirements/test.in # via -r requirements/test.in
...@@ -415,12 +395,14 @@ joblib==1.4.2 ...@@ -415,12 +395,14 @@ joblib==1.4.2
# librosa # librosa
# nltk # nltk
# scikit-learn # scikit-learn
jsonargparse==4.35.0 jsonargparse==4.46.0
# via # via
# lightning # lightning
# terratorch # terratorch
jsonlines==4.0.0 jsonlines==4.0.0
# via lm-eval # via lm-eval
jsonnet==0.21.0
# via jsonargparse
jsonpointer==3.0.0 jsonpointer==3.0.0
# via jsonschema # via jsonschema
jsonschema==4.23.0 jsonschema==4.23.0
...@@ -449,13 +431,13 @@ libnacl==2.1.0 ...@@ -449,13 +431,13 @@ libnacl==2.1.0
# via tensorizer # via tensorizer
librosa==0.10.2.post1 librosa==0.10.2.post1
# via -r requirements/test.in # via -r requirements/test.in
lightly==1.5.20 lightly==1.5.22
# via # via
# terratorch # terratorch
# torchgeo # torchgeo
lightly-utils==0.0.2 lightly-utils==0.0.2
# via lightly # via lightly
lightning==2.5.1.post0 lightning==2.6.1
# via # via
# terratorch # terratorch
# torchgeo # torchgeo
...@@ -476,12 +458,11 @@ lxml==5.3.0 ...@@ -476,12 +458,11 @@ lxml==5.3.0
mako==1.3.10 mako==1.3.10
# via alembic # via alembic
markdown==3.8.2 markdown==3.8.2
# via mlflow # via tensorboard
markdown-it-py==3.0.0 markdown-it-py==3.0.0
# via rich # via rich
markupsafe==3.0.1 markupsafe==3.0.1
# via # via
# flask
# jinja2 # jinja2
# mako # mako
# werkzeug # werkzeug
...@@ -489,7 +470,6 @@ matplotlib==3.9.2 ...@@ -489,7 +470,6 @@ matplotlib==3.9.2
# via # via
# -r requirements/test.in # -r requirements/test.in
# lightning # lightning
# mlflow
# pycocotools # pycocotools
# torchgeo # torchgeo
mbstrdecoder==1.1.3 mbstrdecoder==1.1.3
...@@ -501,10 +481,6 @@ mdurl==0.1.2 ...@@ -501,10 +481,6 @@ mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.9.1 mistral-common==1.9.1
# via -r requirements/test.in # via -r requirements/test.in
mlflow==2.22.0
# via terratorch
mlflow-skinny==2.22.0
# via mlflow
more-itertools==10.5.0 more-itertools==10.5.0
# via lm-eval # via lm-eval
mpmath==1.3.0 mpmath==1.3.0
...@@ -523,8 +499,6 @@ multiprocess==0.70.16 ...@@ -523,8 +499,6 @@ multiprocess==0.70.16
# via # via
# datasets # datasets
# evaluate # evaluate
munch==4.0.0
# via pretrainedmodels
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via black # via black
networkx==3.2.1 networkx==3.2.1
...@@ -553,6 +527,7 @@ numpy==2.2.6 ...@@ -553,6 +527,7 @@ numpy==2.2.6
# cupy-cuda12x # cupy-cuda12x
# datasets # datasets
# decord # decord
# diffusers
# einx # einx
# encodec # encodec
# evaluate # evaluate
...@@ -560,13 +535,13 @@ numpy==2.2.6 ...@@ -560,13 +535,13 @@ numpy==2.2.6
# genai-perf # genai-perf
# geopandas # geopandas
# h5py # h5py
# imagehash
# imageio # imageio
# librosa # librosa
# lightly # lightly
# lightly-utils # lightly-utils
# matplotlib # matplotlib
# mistral-common # mistral-common
# mlflow
# mteb # mteb
# numba # numba
# numexpr # numexpr
...@@ -578,6 +553,7 @@ numpy==2.2.6 ...@@ -578,6 +553,7 @@ numpy==2.2.6
# perceptron # perceptron
# pycocotools # pycocotools
# pyogrio # pyogrio
# pywavelets
# rasterio # rasterio
# rioxarray # rioxarray
# rouge-score # rouge-score
...@@ -590,8 +566,10 @@ numpy==2.2.6 ...@@ -590,8 +566,10 @@ numpy==2.2.6
# shapely # shapely
# soxr # soxr
# statsmodels # statsmodels
# tensorboard
# tensorboardx # tensorboardx
# tensorizer # tensorizer
# terratorch
# tifffile # tifffile
# torchgeo # torchgeo
# torchmetrics # torchmetrics
...@@ -659,7 +637,6 @@ opencv-python-headless==4.13.0.90 ...@@ -659,7 +637,6 @@ opencv-python-headless==4.13.0.90
# mistral-common # mistral-common
opentelemetry-api==1.35.0 opentelemetry-api==1.35.0
# via # via
# mlflow-skinny
# opentelemetry-exporter-prometheus # opentelemetry-exporter-prometheus
# opentelemetry-sdk # opentelemetry-sdk
# opentelemetry-semantic-conventions # opentelemetry-semantic-conventions
...@@ -669,7 +646,6 @@ opentelemetry-proto==1.36.0 ...@@ -669,7 +646,6 @@ opentelemetry-proto==1.36.0
# via ray # via ray
opentelemetry-sdk==1.35.0 opentelemetry-sdk==1.35.0
# via # via
# mlflow-skinny
# opentelemetry-exporter-prometheus # opentelemetry-exporter-prometheus
# ray # ray
opentelemetry-semantic-conventions==0.56b0 opentelemetry-semantic-conventions==0.56b0
...@@ -687,7 +663,6 @@ packaging==24.2 ...@@ -687,7 +663,6 @@ packaging==24.2
# evaluate # evaluate
# fastparquet # fastparquet
# geopandas # geopandas
# gunicorn
# huggingface-hub # huggingface-hub
# hydra-core # hydra-core
# kornia # kornia
...@@ -695,7 +670,6 @@ packaging==24.2 ...@@ -695,7 +670,6 @@ packaging==24.2
# lightning # lightning
# lightning-utilities # lightning-utilities
# matplotlib # matplotlib
# mlflow-skinny
# optuna # optuna
# peft # peft
# plotly # plotly
...@@ -708,10 +682,12 @@ packaging==24.2 ...@@ -708,10 +682,12 @@ packaging==24.2
# rioxarray # rioxarray
# scikit-image # scikit-image
# statsmodels # statsmodels
# tensorboard
# tensorboardx # tensorboardx
# torchmetrics # torchmetrics
# transformers # transformers
# typepy # typepy
# wandb
# xarray # xarray
pandas==2.2.3 pandas==2.2.3
# via # via
...@@ -720,8 +696,8 @@ pandas==2.2.3 ...@@ -720,8 +696,8 @@ pandas==2.2.3
# fastparquet # fastparquet
# genai-perf # genai-perf
# geopandas # geopandas
# mlflow
# statsmodels # statsmodels
# tacoreader
# torchgeo # torchgeo
# xarray # xarray
pathspec==0.12.1 pathspec==0.12.1
...@@ -740,7 +716,9 @@ perf-analyzer==0.1.0 ...@@ -740,7 +716,9 @@ perf-analyzer==0.1.0
# via genai-perf # via genai-perf
pillow==10.4.0 pillow==10.4.0
# via # via
# diffusers
# genai-perf # genai-perf
# imagehash
# imageio # imageio
# lightly-utils # lightly-utils
# matplotlib # matplotlib
...@@ -748,6 +726,7 @@ pillow==10.4.0 ...@@ -748,6 +726,7 @@ pillow==10.4.0
# perceptron # perceptron
# scikit-image # scikit-image
# segmentation-models-pytorch # segmentation-models-pytorch
# tensorboard
# torchgeo # torchgeo
# torchvision # torchvision
platformdirs==4.3.6 platformdirs==4.3.6
...@@ -755,6 +734,7 @@ platformdirs==4.3.6 ...@@ -755,6 +734,7 @@ platformdirs==4.3.6
# black # black
# pooch # pooch
# virtualenv # virtualenv
# wandb
plotly==5.24.1 plotly==5.24.1
# via genai-perf # via genai-perf
pluggy==1.5.0 pluggy==1.5.0
...@@ -769,8 +749,6 @@ portalocker==2.10.1 ...@@ -769,8 +749,6 @@ portalocker==2.10.1
# via sacrebleu # via sacrebleu
pqdm==0.2.0 pqdm==0.2.0
# via -r requirements/test.in # via -r requirements/test.in
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0 prometheus-client==0.22.0
# via # via
# opentelemetry-exporter-prometheus # opentelemetry-exporter-prometheus
...@@ -786,12 +764,13 @@ protobuf==6.33.2 ...@@ -786,12 +764,13 @@ protobuf==6.33.2
# google-api-core # google-api-core
# googleapis-common-protos # googleapis-common-protos
# grpcio-tools # grpcio-tools
# mlflow-skinny
# opentelemetry-proto # opentelemetry-proto
# proto-plus # proto-plus
# ray # ray
# tensorboard
# tensorboardx # tensorboardx
# tensorizer # tensorizer
# wandb
psutil==6.1.0 psutil==6.1.0
# via # via
# accelerate # accelerate
...@@ -801,11 +780,12 @@ py==1.11.0 ...@@ -801,11 +780,12 @@ py==1.11.0
# via pytest-forked # via pytest-forked
py-spy==0.4.0 py-spy==0.4.0
# via ray # via ray
pyarrow==18.0.0 pyarrow==23.0.0
# via # via
# datasets # datasets
# genai-perf # genai-perf
# mlflow # tacoreader
# terratorch
pyasn1==0.6.1 pyasn1==0.6.1
# via # via
# pyasn1-modules # pyasn1-modules
...@@ -831,11 +811,11 @@ pydantic==2.12.0 ...@@ -831,11 +811,11 @@ pydantic==2.12.0
# gpt-oss # gpt-oss
# lightly # lightly
# mistral-common # mistral-common
# mlflow-skinny
# mteb # mteb
# openai-harmony # openai-harmony
# pydantic-extra-types # pydantic-extra-types
# ray # ray
# wandb
pydantic-core==2.41.1 pydantic-core==2.41.1
# via pydantic # via pydantic
pydantic-extra-types==2.10.5 pydantic-extra-types==2.10.5
...@@ -873,7 +853,6 @@ pytest==8.3.5 ...@@ -873,7 +853,6 @@ pytest==8.3.5
# pytest-subtests # pytest-subtests
# pytest-timeout # pytest-timeout
# schemathesis # schemathesis
# terratorch
pytest-asyncio==0.24.0 pytest-asyncio==0.24.0
# via -r requirements/test.in # via -r requirements/test.in
pytest-cov==6.3.0 pytest-cov==6.3.0
...@@ -896,7 +875,6 @@ python-dateutil==2.9.0.post0 ...@@ -896,7 +875,6 @@ python-dateutil==2.9.0.post0
# via # via
# arrow # arrow
# botocore # botocore
# graphene
# lightly # lightly
# matplotlib # matplotlib
# pandas # pandas
...@@ -913,6 +891,8 @@ pytz==2024.2 ...@@ -913,6 +891,8 @@ pytz==2024.2
# via # via
# pandas # pandas
# typepy # typepy
pywavelets==1.9.0
# via imagehash
pyyaml==6.0.2 pyyaml==6.0.2
# via # via
# accelerate # accelerate
...@@ -923,7 +903,6 @@ pyyaml==6.0.2 ...@@ -923,7 +903,6 @@ pyyaml==6.0.2
# huggingface-hub # huggingface-hub
# jsonargparse # jsonargparse
# lightning # lightning
# mlflow-skinny
# omegaconf # omegaconf
# optuna # optuna
# peft # peft
...@@ -934,6 +913,7 @@ pyyaml==6.0.2 ...@@ -934,6 +913,7 @@ pyyaml==6.0.2
# timm # timm
# transformers # transformers
# vocos # vocos
# wandb
rapidfuzz==3.12.1 rapidfuzz==3.12.1
# via jiwer # via jiwer
rasterio==1.4.3 rasterio==1.4.3
...@@ -951,6 +931,7 @@ referencing==0.35.1 ...@@ -951,6 +931,7 @@ referencing==0.35.1
# jsonschema-specifications # jsonschema-specifications
regex==2024.9.11 regex==2024.9.11
# via # via
# diffusers
# nltk # nltk
# open-clip-torch # open-clip-torch
# sacrebleu # sacrebleu
...@@ -959,8 +940,8 @@ regex==2024.9.11 ...@@ -959,8 +940,8 @@ regex==2024.9.11
requests==2.32.3 requests==2.32.3
# via # via
# buildkite-test-collector # buildkite-test-collector
# databricks-sdk
# datasets # datasets
# diffusers
# docker # docker
# evaluate # evaluate
# google-api-core # google-api-core
...@@ -970,15 +951,16 @@ requests==2.32.3 ...@@ -970,15 +951,16 @@ requests==2.32.3
# lightly # lightly
# lm-eval # lm-eval
# mistral-common # mistral-common
# mlflow-skinny
# mteb # mteb
# pooch # pooch
# ray # ray
# responses # responses
# schemathesis # schemathesis
# starlette-testclient # starlette-testclient
# tacoreader
# tiktoken # tiktoken
# transformers # transformers
# wandb
responses==0.25.3 responses==0.25.3
# via genai-perf # via genai-perf
rfc3339-validator==0.1.4 rfc3339-validator==0.1.4
...@@ -991,6 +973,7 @@ rich==13.9.4 ...@@ -991,6 +973,7 @@ rich==13.9.4
# lightning # lightning
# mteb # mteb
# perceptron # perceptron
# terratorch
# typer # typer
rioxarray==0.19.0 rioxarray==0.19.0
# via terratorch # via terratorch
...@@ -1017,47 +1000,55 @@ sacrebleu==2.4.3 ...@@ -1017,47 +1000,55 @@ sacrebleu==2.4.3
safetensors==0.4.5 safetensors==0.4.5
# via # via
# accelerate # accelerate
# diffusers
# open-clip-torch # open-clip-torch
# peft # peft
# segmentation-models-pytorch
# timm # timm
# transformers # transformers
schemathesis==3.39.15 schemathesis==3.39.15
# via -r requirements/test.in # via -r requirements/test.in
scikit-image==0.25.2 scikit-image==0.25.2
# via albumentations # via
# albumentations
# terratorch
scikit-learn==1.5.2 scikit-learn==1.5.2
# via # via
# albumentations # albumentations
# librosa # librosa
# lm-eval # lm-eval
# mlflow
# mteb # mteb
# sentence-transformers # sentence-transformers
# terratorch
scipy==1.13.1 scipy==1.13.1
# via # via
# albumentations # albumentations
# bm25s # bm25s
# imagehash
# librosa # librosa
# mlflow
# mteb # mteb
# scikit-image # scikit-image
# scikit-learn # scikit-learn
# sentence-transformers # sentence-transformers
# statsmodels # statsmodels
# vocos # vocos
segmentation-models-pytorch==0.4.0 segmentation-models-pytorch==0.5.0
# via # via
# -r requirements/test.in
# terratorch # terratorch
# torchgeo # torchgeo
sentence-transformers==5.2.0 sentence-transformers==5.2.0
# via # via
# -r requirements/test.in # -r requirements/test.in
# mteb # mteb
sentry-sdk==2.52.0
# via wandb
setuptools==77.0.3 setuptools==77.0.3
# via # via
# grpcio-tools # grpcio-tools
# lightning-utilities # lightning-utilities
# pytablewriter # pytablewriter
# tensorboard
# torch # torch
shapely==2.1.1 shapely==2.1.1
# via # via
...@@ -1075,7 +1066,6 @@ six==1.16.0 ...@@ -1075,7 +1066,6 @@ six==1.16.0
# python-dateutil # python-dateutil
# rfc3339-validator # rfc3339-validator
# rouge-score # rouge-score
# segmentation-models-pytorch
smart-open==7.1.0 smart-open==7.1.0
# via ray # via ray
smmap==5.0.2 smmap==5.0.2
...@@ -1099,12 +1089,9 @@ soxr==0.5.0.post1 ...@@ -1099,12 +1089,9 @@ soxr==0.5.0.post1
sqlalchemy==2.0.41 sqlalchemy==2.0.41
# via # via
# alembic # alembic
# mlflow
# optuna # optuna
sqlitedict==2.1.0 sqlitedict==2.1.0
# via lm-eval # via lm-eval
sqlparse==0.5.3
# via mlflow-skinny
starlette==0.50.0 starlette==0.50.0
# via # via
# fastapi # fastapi
...@@ -1124,6 +1111,8 @@ tabledata==1.3.3 ...@@ -1124,6 +1111,8 @@ tabledata==1.3.3
# via pytablewriter # via pytablewriter
tabulate==0.9.0 tabulate==0.9.0
# via sacrebleu # via sacrebleu
tacoreader==0.5.6
# via terratorch
tblib==3.1.0 tblib==3.1.0
# via -r requirements/test.in # via -r requirements/test.in
tcolorpy==0.1.6 tcolorpy==0.1.6
...@@ -1133,13 +1122,19 @@ tenacity==9.1.2 ...@@ -1133,13 +1122,19 @@ tenacity==9.1.2
# gpt-oss # gpt-oss
# lm-eval # lm-eval
# plotly # plotly
tensorboard==2.20.0
# via terratorch
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4 tensorboardx==2.6.4
# via lightning # via lightning
tensorizer==2.10.1 tensorizer==2.10.1
# via -r requirements/test.in # via -r requirements/test.in
termcolor==3.1.0 termcolor==3.1.0
# via gpt-oss # via
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # gpt-oss
# terratorch
terratorch==1.2.2
# via -r requirements/test.in # via -r requirements/test.in
threadpoolctl==3.5.0 threadpoolctl==3.5.0
# via scikit-learn # via scikit-learn
...@@ -1172,7 +1167,6 @@ torch==2.10.0+cu129 ...@@ -1172,7 +1167,6 @@ torch==2.10.0+cu129
# -r requirements/test.in # -r requirements/test.in
# accelerate # accelerate
# bitsandbytes # bitsandbytes
# efficientnet-pytorch
# encodec # encodec
# kornia # kornia
# lightly # lightly
...@@ -1181,7 +1175,6 @@ torch==2.10.0+cu129 ...@@ -1181,7 +1175,6 @@ torch==2.10.0+cu129
# mteb # mteb
# open-clip-torch # open-clip-torch
# peft # peft
# pretrainedmodels
# pytorch-lightning # pytorch-lightning
# runai-model-streamer # runai-model-streamer
# segmentation-models-pytorch # segmentation-models-pytorch
...@@ -1213,12 +1206,11 @@ torchvision==0.25.0+cu129 ...@@ -1213,12 +1206,11 @@ torchvision==0.25.0+cu129
# -r requirements/test.in # -r requirements/test.in
# lightly # lightly
# open-clip-torch # open-clip-torch
# pretrainedmodels
# segmentation-models-pytorch # segmentation-models-pytorch
# terratorch # terratorch
# timm # timm
# torchgeo # torchgeo
tqdm==4.66.6 tqdm==4.67.3
# via # via
# datasets # datasets
# evaluate # evaluate
...@@ -1232,10 +1224,11 @@ tqdm==4.66.6 ...@@ -1232,10 +1224,11 @@ tqdm==4.66.6
# optuna # optuna
# peft # peft
# pqdm # pqdm
# pretrainedmodels
# pytorch-lightning # pytorch-lightning
# segmentation-models-pytorch # segmentation-models-pytorch
# sentence-transformers # sentence-transformers
# tacoreader
# terratorch
# tqdm-multiprocess # tqdm-multiprocess
# transformers # transformers
tqdm-multiprocess==0.0.11 tqdm-multiprocess==0.0.11
...@@ -1274,14 +1267,12 @@ typing-extensions==4.15.0 ...@@ -1274,14 +1267,12 @@ typing-extensions==4.15.0
# alembic # alembic
# chz # chz
# fastapi # fastapi
# graphene
# grpcio # grpcio
# huggingface-hub # huggingface-hub
# librosa # librosa
# lightning # lightning
# lightning-utilities # lightning-utilities
# mistral-common # mistral-common
# mlflow-skinny
# mteb # mteb
# opentelemetry-api # opentelemetry-api
# opentelemetry-sdk # opentelemetry-sdk
...@@ -1299,6 +1290,7 @@ typing-extensions==4.15.0 ...@@ -1299,6 +1290,7 @@ typing-extensions==4.15.0
# typer # typer
# typeshed-client # typeshed-client
# typing-inspection # typing-inspection
# wandb
typing-inspection==0.4.2 typing-inspection==0.4.2
# via pydantic # via pydantic
tzdata==2024.2 tzdata==2024.2
...@@ -1313,25 +1305,26 @@ urllib3==2.2.3 ...@@ -1313,25 +1305,26 @@ urllib3==2.2.3
# lightly # lightly
# requests # requests
# responses # responses
# sentry-sdk
# tritonclient # tritonclient
uvicorn==0.35.0 uvicorn==0.35.0
# via # via gpt-oss
# gpt-oss
# mlflow-skinny
vector-quantize-pytorch==1.21.2 vector-quantize-pytorch==1.21.2
# via -r requirements/test.in # via -r requirements/test.in
virtualenv==20.31.2 virtualenv==20.31.2
# via ray # via ray
vocos==0.1.0 vocos==0.1.0
# via -r requirements/test.in # via -r requirements/test.in
wandb==0.24.2
# via terratorch
wcwidth==0.2.13 wcwidth==0.2.13
# via ftfy # via ftfy
webcolors==24.11.1 webcolors==24.11.1
# via jsonschema # via jsonschema
werkzeug==3.1.3 werkzeug==3.1.3
# via # via
# flask
# schemathesis # schemathesis
# tensorboard
word2number==1.1 word2number==1.1
# via lm-eval # via lm-eval
wrapt==1.17.2 wrapt==1.17.2
......
...@@ -40,7 +40,7 @@ def _run_test( ...@@ -40,7 +40,7 @@ def _run_test(
vllm_model.llm.encode(prompt, pooling_task="plugin") vllm_model.llm.encode(prompt, pooling_task="plugin")
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] MODELS = ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@pytest.mark.core_model @pytest.mark.core_model
......
...@@ -13,7 +13,7 @@ from tests.utils import create_new_process_for_each_test ...@@ -13,7 +13,7 @@ from tests.utils import create_new_process_for_each_test
"model", "model",
[ [
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"mgazz/Prithvi_v2_eo_300_tl_unet_agb", "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
], ],
) )
def test_inference( def test_inference(
......
...@@ -44,12 +44,8 @@ datamodule_config: DataModuleConfig = { ...@@ -44,12 +44,8 @@ datamodule_config: DataModuleConfig = {
"no_label_replace": -1, "no_label_replace": -1,
"num_workers": 8, "num_workers": 8,
"test_transform": [ "test_transform": [
albumentations.Resize( albumentations.Resize(height=448, interpolation=1, p=1, width=448),
always_apply=False, height=448, interpolation=1, p=1, width=448 albumentations.pytorch.ToTensorV2(transpose_mask=False, p=1.0),
),
albumentations.pytorch.ToTensorV2(
transpose_mask=False, always_apply=True, p=1.0
),
], ],
} }
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64 import base64
import io
import imagehash
import pytest import pytest
import requests import requests
from PIL import Image
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" models_config = {
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
"image_url": "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff", # noqa: E501
"out_hash": "aa6d92ad25926a5e",
"plugin": "prithvi_to_tiff",
},
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars": {
"image_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars/resolve/main/examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif", # noqa: E501
"out_hash": "c07f4f602da73552",
"plugin": "prithvi_to_tiff",
},
}
def _compute_image_hash(base64_data: str) -> str:
# Decode the base64 output and create image from byte stream
decoded_image = base64.b64decode(base64_data)
image = Image.open(io.BytesIO(decoded_image))
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 # Compute perceptual hash of the output image
return str(imagehash.phash(image))
def test_loading_missing_plugin(): def test_loading_missing_plugin():
...@@ -22,33 +43,39 @@ def test_loading_missing_plugin(): ...@@ -22,33 +43,39 @@ def test_loading_missing_plugin():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def server(): def server(model_name, plugin):
args = [ args = [
"--runner", "--runner",
"pooling", "pooling",
"--enforce-eager", "--enforce-eager",
"--trust-remote-code",
"--skip-tokenizer-init", "--skip-tokenizer-init",
# Limit the maximum number of parallel requests # Limit the maximum number of parallel requests
# to avoid the model going OOM in CI. # to avoid the model going OOM in CI.
"--max-num-seqs", "--max-num-seqs",
"32", "32",
"--io-processor-plugin", "--io-processor-plugin",
"prithvi_to_tiff", plugin,
"--model-impl",
"terratorch",
"--enable-mm-embeds", "--enable-mm-embeds",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(model_name, args) as remote_server:
yield remote_server yield remote_server
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize(
"model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
async def test_prithvi_mae_plugin_online( async def test_prithvi_mae_plugin_online(
server: RemoteOpenAIServer, server: RemoteOpenAIServer,
model_name: str, model_name: str,
image_url: str | dict,
plugin: str,
expected_hash: str,
): ):
request_payload_url = { request_payload_url = {
"data": { "data": {
...@@ -74,16 +101,25 @@ async def test_prithvi_mae_plugin_online( ...@@ -74,16 +101,25 @@ async def test_prithvi_mae_plugin_online(
# verify the output is formatted as expected for this plugin # verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data plugin_data = parsed_response.data
assert all(plugin_data.get(attr) for attr in ["type", "format", "data"]) assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])
# We just check that the output is a valid base64 string. # Compute the output image hash and compare it against the expected hash
# Raises an exception and fails the test if the string is corrupted. image_hash = _compute_image_hash(plugin_data["data"])
base64.b64decode(plugin_data["data"]) assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
)
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize(
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): "model_name, image_url, plugin, expected_hash",
[
(model_name, config["image_url"], config["plugin"], config["out_hash"])
for model_name, config in models_config.items()
],
)
def test_prithvi_mae_plugin_offline(
vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
):
img_prompt = dict( img_prompt = dict(
data=image_url, data=image_url,
data_format="url", data_format="url",
...@@ -96,13 +132,12 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): ...@@ -96,13 +132,12 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
runner="pooling", runner="pooling",
skip_tokenizer_init=True, skip_tokenizer_init=True,
enable_mm_embeds=True, enable_mm_embeds=True,
trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
# Limit the maximum number of parallel requests # Limit the maximum number of parallel requests
# to avoid the model going OOM in CI. # to avoid the model going OOM in CI.
max_num_seqs=1, max_num_seqs=32,
model_impl="terratorch", io_processor_plugin=plugin,
io_processor_plugin="prithvi_to_tiff", default_torch_num_threads=1,
) as llm_runner: ) as llm_runner:
pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin") pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
output = pooler_output[0].outputs output = pooler_output[0].outputs
...@@ -110,6 +145,8 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): ...@@ -110,6 +145,8 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
# verify the output is formatted as expected for this plugin # verify the output is formatted as expected for this plugin
assert all(hasattr(output, attr) for attr in ["type", "format", "data"]) assert all(hasattr(output, attr) for attr in ["type", "format", "data"])
# We just check that the output is a valid base64 string. # Compute the output image hash and compare it against the expected hash
# Raises an exception and fails the test if the string is corrupted. image_hash = _compute_image_hash(output.data)
base64.b64decode(output.data) assert image_hash == expected_hash, (
f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment