# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.conftest import VllmRunner


@pytest.mark.parametrize(
    "model",
    [
        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        "mgazz/Prithvi_v2_eo_300_tl_unet_agb",
    ],
)
def test_inference(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:
    """Smoke-test pooling inference: the encoded output must contain no NaNs.

    A single prompt is built from constant-valued float16 tensors
    (``pixel_values`` of shape ``(6, 512, 512)`` and ``location_coords`` of
    shape ``(1, 2)``), passed through ``llm.encode`` and the first result's
    output data is checked for NaN values.

    Args:
        vllm_runner: Fixture providing the ``VllmRunner`` context manager.
        model: Name of the model under test (see the parametrize list).
    """
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)

    # The tokenizer is not initialised (skip_tokenizer_init=True below), so
    # the prompt carries an explicit token id list alongside the tensors.
    prompt = {
        "prompt_token_ids": [1],
        "multi_modal_data": {
            "pixel_values": pixel_values,
            "location_coords": location_coords,
        },
    }

    with vllm_runner(
        model,
        runner="pooling",
        dtype="half",
        enforce_eager=True,
        skip_tokenizer_init=True,
        # Limit the maximum number of sequences to avoid the
        # test going OOM during the warmup run
        max_num_seqs=32,
        default_torch_num_threads=1,
    ) as vllm_model:
        vllm_output = vllm_model.llm.encode(prompt)
        output_data = vllm_output[0].outputs.data

        # Truth-testing the 0-dim result of ``.any()`` avoids building a
        # separate comparison tensor (and any device mismatch with it).
        assert not torch.isnan(output_data).any(), (
            f"{model}: pooling output contains NaN values"
        )