# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import os

import requests

# This example shows how to perform online inference that generates
# multimodal data. In this specific case, the example takes a GeoTIFF
# image as input, processes it using the multimodal data processor, and
# performs inference.
# Requirements:
# - install plugin at:
#   https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args
#   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
#   --model-impl terratorch
#   --task embed --trust-remote-code
#   --skip-tokenizer-init --enforce-eager
#   --io-processor-plugin prithvi_to_tiff_india


def main():
    """Send one pooling request to a local vLLM server and save the result.

    Posts a JSON payload referencing a remote GeoTIFF to the server's
    ``/pooling`` endpoint, prints the HTTP status code and reason, then
    base64-decodes the ``data.data`` field of the JSON response and writes
    the bytes to ``online_prediction.tiff`` in the current working directory.

    Raises:
        requests.HTTPError: If the server replies with a 4xx/5xx status.
        requests.RequestException: If the request cannot be completed
            (e.g. connection failure or timeout).
    """
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501
    server_endpoint = "http://localhost:8000/pooling"

    # NOTE(review): the inner "data" dict is presumably interpreted by the
    # IO processor plugin named in the header comment; confirm field
    # semantics against the plugin's documentation.
    request_payload_url = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        "softmax": False,
    }

    # Without a timeout, requests waits forever on an unresponsive server.
    # (connect, read) in seconds; the read timeout is deliberately generous
    # because the server has to fetch the image and run inference.
    ret = requests.post(
        server_endpoint,
        json=request_payload_url,
        timeout=(10, 600),
    )

    print(f"response.status_code: {ret.status_code}")
    print(f"response.reason:{ret.reason}")

    # Fail with a clear HTTP error instead of a confusing JSON/KeyError
    # further down when the server rejected the request.
    ret.raise_for_status()

    response = ret.json()

    decoded_image = base64.b64decode(response["data"]["data"])

    out_path = os.path.join(os.getcwd(), "online_prediction.tiff")

    with open(out_path, "wb") as f:
        f.write(decoded_image)


# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()