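"""Tests for ModelRunner's prompt-phase input preparation."""
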
import random
import torch

from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


def test_prepare_prompt():
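    # None is passed for every constructor argument because this test only
    # exercises input preparation, which does not touch them.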
    model_runner = ModelRunner(None, None, None, None)
    model_runner.set_block_size(16)

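    # Build a batch of prompt-phase sequence groups with varying prompt
    # lengths. Token IDs are chosen to equal their positions (0..len-1) so
    # that input_tokens and input_positions can be compared directly below.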
    batch_size = random.randint(1, 256)
    prompt_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # Make sure all tokens fit into one block (prompt_len < block_size).
        prompt_len = i % (model_runner.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        seq_data = list(range(prompt_len))
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData(seq_data)},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))

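    # _prepare_prompt pads every prompt to max_seq_len, so in the flattened
    # batch the last token of prompt j sits at j * max_seq_len + prompt_len - 1.
    # Those are the indices the sampler should select logits from.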
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    max_seq_len = max(prompt_lens)
    for prompt_len in prompt_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               prompt_len - 1)
        selected_token_start_idx += max_seq_len
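
    # Exercise the actual input preparation and check that it round-trips
    # the prompt lengths.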
    input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
        model_runner._prepare_prompt(seq_group_metadata_list))
    assert return_prompt_lens == prompt_lens
    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
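
    # Prompts are right-padded to the longest prompt in the batch; because
    # token IDs were set equal to positions above, the two tensors must match.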
    assert input_tokens.shape == (batch_size, max_seq_len)
    assert input_positions.shape == (batch_size, max_seq_len)
    torch.testing.assert_close(input_tokens, input_positions)

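    # The sampler must pick exactly the last (non-padded) token of each prompt.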
    actual = sampling_metadata.selected_token_indices
47
48
49
50
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
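

# A minimal way to invoke just this test (file path assumed; adjust to your
# checkout layout):
#   pytest test_model_runner.py::test_prepare_prompt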