Unverified commit cd3aa153 authored by Woosuk Kwon, committed by GitHub

Fix broken worker test (#1900)

parent 9b294976
@@ -2,18 +2,19 @@ import random
 import torch
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner
 
 
-def test_worker_prepare_inputs_for_prompt():
-    worker = Worker(None, None, None)
-    worker.block_size = 16
+def test_prepare_prompt():
+    model_runner = ModelRunner(None, None, None)
+    model_runner.set_block_size(16)
     batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
         # make sure all tokens fit into one block
-        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
         seq_data = list(range(prompt_len))
         seq_group_metadata_list.append(
@@ -24,6 +25,7 @@ def test_worker_prepare_inputs_for_prompt():
             sampling_params=SamplingParams(temperature=0),
             block_tables={0: [1]},
         ))
     expected_selected_token_indices = []
     selected_token_start_idx = 0
     max_seq_len = max(prompt_lens)
@@ -31,12 +33,15 @@ def test_worker_prepare_inputs_for_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+    input_tokens, input_positions, _ = model_runner._prepare_prompt(
         seq_group_metadata_list)
-    assert input_tokens.shape == input_positions.shape == (batch_size,
-                                                           max_seq_len)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    assert input_tokens.shape == (batch_size, max_seq_len)
+    assert input_positions.shape == (batch_size, max_seq_len)
     torch.testing.assert_close(input_tokens, input_positions)
-    actual = input_metadata.selected_token_indices
+    actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
                             device=actual.device,
                             dtype=actual.dtype)
...
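For readers skimming the diff, the index math that the test asserts can be reproduced in isolation. The following is an illustrative sketch, not part of the commit: it assumes (as the test does) that prompts in a prefill batch are right-padded to the batch's longest prompt, so the last token of sequence i sits at offset i * max_seq_len + prompt_len - 1 in the flattened token buffer; the example prompt lengths are made up.

# Illustrative sketch (not from the commit): the padded-batch index math
# that expected_selected_token_indices encodes in the test above.
prompt_lens = [3, 1, 4]          # example prompt lengths (hypothetical values)
max_seq_len = max(prompt_lens)   # every prompt is padded to this length

expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
    # Only the last prompt token of each sequence is selected for sampling.
    expected_selected_token_indices.append(selected_token_start_idx +
                                           prompt_len - 1)
    selected_token_start_idx += max_seq_len

print(expected_selected_token_indices)  # [2, 4, 11]

The updated test compares these indices against sampling_metadata.selected_token_indices returned by ModelRunner._prepare_sample, instead of the old input_metadata from Worker._prepare_inputs.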