# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import os

import requests

# This example shows how to perform online inference that generates
# multimodal data. Specifically, it takes a GeoTIFF image as input, has the
# server process it with the multimodal data processor, runs inference, and
# writes the resulting prediction to online_prediction.tiff in the current
# working directory.

# Requirements:
# - install TerraTorch v1.1 (or later); quote the specifier so the shell
#   does not treat ">" as an output redirect:
#   pip install "terratorch>=1.1"
# - start vLLM in serving mode with the following args:
#   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
#   --model-impl terratorch
#   --trust-remote-code
#   --skip-tokenizer-init --enforce-eager
#   --io-processor-plugin terratorch_segmentation
#   --enable-mm-embeds


def main():
    """Send a GeoTIFF to a running vLLM server and save the predicted image.

    The request is posted to the server's ``/pooling`` endpoint, with the
    input image referenced by URL. The output is requested as
    ``b64_json``, so ``response["data"]["data"]`` is base64-decoded and
    written to ``online_prediction.tiff`` in the current working directory.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not answer within the timeout.
    """
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501
    server_endpoint = "http://localhost:8000/pooling"
    # requests waits forever by default; use a generous bound so a stalled
    # server cannot hang the script indefinitely.
    request_timeout_s = 600

    request_payload_url = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    }

    ret = requests.post(
        server_endpoint,
        json=request_payload_url,
        timeout=request_timeout_s,
    )

    print(f"response.status_code: {ret.status_code}")
    print(f"response.reason:{ret.reason}")
    # Surface server-side failures as a clear HTTP error (after the
    # diagnostics above) rather than a KeyError / JSON decode error below.
    ret.raise_for_status()

    response = ret.json()

    decoded_image = base64.b64decode(response["data"]["data"])

    out_path = os.path.join(os.getcwd(), "online_prediction.tiff")

    with open(out_path, "wb") as f:
        f.write(decoded_image)

    print(f"Prediction written to {out_path}")


if __name__ == "__main__":
    main()