# test_terratorch.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.conftest import VllmRunner
8
from tests.utils import create_new_process_for_each_test
9
10


11
@create_new_process_for_each_test()  # Memory is not cleaned up properly otherwise
@pytest.mark.parametrize(
    "model",
    [
        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
    ],
)
def test_inference(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:
    """Smoke-test plugin pooling inference for the parametrized models.

    Builds a single synthetic multimodal prompt (constant-valued float16
    tensors), runs it through ``LLM.encode`` with ``pooling_task="plugin"``
    and checks that the first returned output contains no NaN values.

    Args:
        vllm_runner: Fixture providing the runner context-manager class.
        model: Hugging Face model identifier to load.
    """
    # Synthetic inputs filled with 1.0; only the shapes and dtype matter here.
    # NOTE(review): the leading 6 is presumably the number of input bands
    # expected by these checkpoints — confirm against the model config.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)

    # A single placeholder token id is supplied because the tokenizer is
    # skipped below (skip_tokenizer_init=True).
    prompt = {
        "prompt_token_ids": [1],
        "multi_modal_data": {
            "image": {
                "pixel_values": pixel_values,
                "location_coords": location_coords,
            }
        },
    }

    with vllm_runner(
        model,
        runner="pooling",
        dtype="half",
        enforce_eager=True,
        skip_tokenizer_init=True,
        enable_mm_embeds=True,
        # Limit the maximum number of sequences to avoid the
        # test going OOM during the warmup run
        max_num_seqs=32,
        default_torch_num_threads=1,
    ) as vllm_model:
        vllm_output = vllm_model.llm.encode(prompt, pooling_task="plugin")

        # Truth-testing the 0-dim bool tensor works regardless of the device
        # the output lives on, unlike comparing against a fresh CPU tensor.
        output_data = vllm_output[0].outputs.data
        assert not torch.isnan(output_data).any(), (
            f"NaN values found in pooling output for model {model!r}"
        )