Unverified Commit 837e1851 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[CI/Build] fix flaky test (#3602)

parent 42bc3861
import random import pytest
import torch import torch
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
def get_aligned_size(batch_size: int, alignment: int): @pytest.mark.parametrize("batch_size", list(range(1, 257)))
return ((batch_size + alignment - 1) // alignment * alignment) def test_prepare_prompt(batch_size):
def test_prepare_prompt():
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(None, None, None, None, None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = [] prompt_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
block_tables = {0: [1]} block_tables = {0: [1]}
...@@ -111,7 +107,8 @@ def test_prepare_prompt(): ...@@ -111,7 +107,8 @@ def test_prepare_prompt():
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
def test_prepare_decode_cuda_graph(): @pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
model_config = ModelConfig( model_config = ModelConfig(
"facebook/opt-125m", "facebook/opt-125m",
"facebook/opt-125m", "facebook/opt-125m",
...@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph(): ...@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
model_runner = ModelRunner(model_config, None, None, None, None) model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = [] prompt_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
for i in range(batch_size): for i in range(batch_size):
...@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph(): ...@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
input_tokens, input_positions, input_metadata, _, _, _ = ( input_tokens, input_positions, input_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list)) model_runner._prepare_decode(seq_group_metadata_list))
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert input_metadata.is_prompt is False assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0 assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == (get_aligned_size( assert input_metadata.num_generation_tokens == expected_bs
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
assert input_metadata.max_seq_len is None assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None assert input_metadata.seq_start_loc is None
...@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph(): ...@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
assert input_metadata.use_cuda_graph is True assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto" assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (get_aligned_size( assert input_tokens.shape == (expected_bs, )
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), ) assert input_positions.shape == (expected_bs, )
assert input_positions.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
torch.testing.assert_close(input_tokens, input_positions) torch.testing.assert_close(input_tokens, input_positions)
# Verify Sampling # Verify Sampling
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment