# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.conftest import VllmRunner
from vllm.utils import set_default_torch_num_threads


@pytest.mark.parametrize(
    "model",
    [
        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        "mgazz/Prithvi_v2_eo_300_tl_unet_agb",
    ],
)
def test_inference(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:
    """Smoke-test a pooling-runner model on a constant multimodal input.

    Builds one prompt made of a single token id plus constant float16
    ``pixel_values`` (shape 6x512x512) and ``location_coords`` (shape 1x2)
    tensors, encodes it with the pooling runner, and asserts that the first
    output's data contains no NaN values.

    Args:
        vllm_runner: The ``VllmRunner`` class (fixture), used as a context
            manager to load ``model``.
        model: Model identifier to load, supplied by ``parametrize``.
    """
    # Constant dummy inputs: the test checks numerical sanity (no NaNs),
    # not the content of the prediction.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
    # The engine is created with skip_tokenizer_init=True below, so the
    # prompt is given directly as token ids alongside the multimodal data.
    prompt = dict(prompt_token_ids=[1],
                  multi_modal_data=dict(pixel_values=pixel_values,
                                        location_coords=location_coords))
    with (
            set_default_torch_num_threads(1),
            vllm_runner(
                model,
                runner="pooling",
                dtype=torch.float16,
                enforce_eager=True,
                skip_tokenizer_init=True,
                # Limit the maximum number of sequences to avoid the
                # test going OOM during the warmup run
                max_num_seqs=32,
            ) as vllm_model,
    ):
        vllm_output = vllm_model.llm.encode(prompt)
        # Reduce the NaN check to a Python bool. Unlike comparing against
        # a freshly built CPU tensor with torch.equal, this does not depend
        # on which device the output tensor lives on.
        assert not torch.isnan(vllm_output[0].outputs.data).any().item()