"vllm/entrypoints/openai/chat_completion/serving.py" did not exist on "9bb38130cb19eb084d39f269cbeae2952789fafd"
test_terratorch.py 1.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.conftest import VllmRunner


@pytest.mark.parametrize(
    "model",
    [
        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        "mgazz/Prithvi_v2_eo_300_tl_unet_agb",
    ],
)
def test_inference(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:
    """Smoke-test plugin pooling inference on a constant dummy input.

    Builds a prompt from constant-valued float16 tensors, encodes it with
    the ``"plugin"`` pooling task, and checks that the data of the first
    returned output contains no NaN values.

    Args:
        vllm_runner: Runner factory; called with the model name and engine
            options and used as a context manager.
        model: Model identifier passed as the first positional argument to
            ``vllm_runner``.
    """
    # Dummy inputs filled with 1.0 in half precision.
    # NOTE(review): the (6, 512, 512) and (1, 2) shapes are presumably what
    # the parametrized models expect -- confirm against the model configs.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)

    # The tokenizer is not initialised below (skip_tokenizer_init=True), so
    # the prompt supplies token ids directly alongside the tensor inputs.
    prompt = {
        "prompt_token_ids": [1],
        "multi_modal_data": {
            "pixel_values": pixel_values,
            "location_coords": location_coords,
        },
    }

    with vllm_runner(
        model,
        runner="pooling",
        dtype="half",
        enforce_eager=True,
        skip_tokenizer_init=True,
        enable_mm_embeds=True,
        # Limit the maximum number of sequences to avoid the
        # test going OOM during the warmup run
        max_num_seqs=32,
        default_torch_num_threads=1,
    ) as vllm_model:
        vllm_output = vllm_model.llm.encode(prompt, pooling_task="plugin")

        # Only the first output is inspected, matching the single prompt
        # submitted above.
        output_data = vllm_output[0].outputs.data
        assert not torch.isnan(output_data).any(), (
            f"NaN values found in the pooled output for model {model!r}"
        )