"benchmarks/vscode:/vscode.git/clone" did not exist on "978aed53004b82877bd2af0f10afff1826d7194d"
Unverified Commit 9cd76b71 authored by Christian Pinto's avatar Christian Pinto Committed by GitHub
Browse files

[Misc] Terratorch related fixes (#24337)


Signed-off-by: default avatarChristian Pinto <christian.pinto@ibm.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent e0413141
...@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams ...@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
def main(): def main():
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
img_prompt = dict( img_prompt = dict(
data=image_url, data=image_url,
...@@ -36,7 +36,7 @@ def main(): ...@@ -36,7 +36,7 @@ def main():
# to avoid the model going OOM. # to avoid the model going OOM.
# The maximum number depends on the available GPU memory # The maximum number depends on the available GPU memory
max_num_seqs=32, max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff_india", io_processor_plugin="prithvi_to_tiff",
model_impl="terratorch", model_impl="terratorch",
) )
......
...@@ -18,11 +18,11 @@ import requests ...@@ -18,11 +18,11 @@ import requests
# --model-impl terratorch # --model-impl terratorch
# --task embed --trust-remote-code # --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager # --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india # --io-processor-plugin prithvi_to_tiff
def main(): def main():
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
server_endpoint = "http://localhost:8000/pooling" server_endpoint = "http://localhost:8000/pooling"
request_payload_url = { request_payload_url = {
......
...@@ -54,4 +54,4 @@ runai-model-streamer-s3==0.11.0 ...@@ -54,4 +54,4 @@ runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10 fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10 pydantic>=2.10 # 2.9 leads to error on python 3.10
decord==0.6.0 decord==0.6.0
terratorch==1.1rc3 # required for PrithviMAE test terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test
...@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4 ...@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
# via lightning # via lightning
tensorizer==2.10.1 tensorizer==2.10.1
# via -r requirements/test.in # via -r requirements/test.in
terratorch==1.1rc3 terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
# via -r requirements/test.in # via -r requirements/test.in
threadpoolctl==3.5.0 threadpoolctl==3.5.0
# via scikit-learn # via scikit-learn
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11" MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16" DTYPE = "float16"
......
...@@ -383,7 +383,7 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -383,7 +383,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True), trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16, dtype=torch.float16,
enforce_eager=True, enforce_eager=True,
skip_tokenizer_init=True, skip_tokenizer_init=True,
...@@ -391,7 +391,7 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -391,7 +391,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
# going OOM in CI # going OOM in CI
max_num_seqs=32, max_num_seqs=32,
), ),
"Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16, dtype=torch.float16,
enforce_eager=True, enforce_eager=True,
skip_tokenizer_init=True, skip_tokenizer_init=True,
......
...@@ -11,7 +11,7 @@ from vllm.utils import set_default_torch_num_threads ...@@ -11,7 +11,7 @@ from vllm.utils import set_default_torch_num_threads
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
"mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"mgazz/Prithvi_v2_eo_300_tl_unet_agb" "mgazz/Prithvi_v2_eo_300_tl_unet_agb"
], ],
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def register_prithvi_india():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501
def register_prithvi_valencia(): def register_prithvi():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501 return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501
...@@ -234,6 +234,8 @@ def load_image( ...@@ -234,6 +234,8 @@ def load_image(
class PrithviMultimodalDataProcessor(IOProcessor): class PrithviMultimodalDataProcessor(IOProcessor):
indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config) super().__init__(vllm_config)
...@@ -412,21 +414,3 @@ class PrithviMultimodalDataProcessor(IOProcessor): ...@@ -412,21 +414,3 @@ class PrithviMultimodalDataProcessor(IOProcessor):
format="tiff", format="tiff",
data=out_data, data=out_data,
request_id=request_id) request_id=request_id)
class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.indices = [1, 2, 3, 8, 11, 12]
class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.indices = [0, 1, 2, 3, 4, 5]
...@@ -9,8 +9,7 @@ setup( ...@@ -9,8 +9,7 @@ setup(
packages=["prithvi_io_processor"], packages=["prithvi_io_processor"],
entry_points={ entry_points={
"vllm.io_processor_plugins": [ "vllm.io_processor_plugins": [
"prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501 "prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501
"prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501
] ]
}, },
) )
...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import IOProcessorResponse ...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11" MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
...@@ -35,7 +35,7 @@ def server(): ...@@ -35,7 +35,7 @@ def server():
"--max-num-seqs", "--max-num-seqs",
"32", "32",
"--io-processor-plugin", "--io-processor-plugin",
"prithvi_to_tiff_valencia", "prithvi_to_tiff",
"--model-impl", "--model-impl",
"terratorch", "terratorch",
] ]
...@@ -107,7 +107,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): ...@@ -107,7 +107,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
# to avoid the model going OOM in CI. # to avoid the model going OOM in CI.
max_num_seqs=1, max_num_seqs=1,
model_impl="terratorch", model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff_valencia", io_processor_plugin="prithvi_to_tiff",
) as llm_runner: ) as llm_runner:
pooler_output = llm_runner.get_llm().encode( pooler_output = llm_runner.get_llm().encode(
img_prompt, img_prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment