Unverified Commit e965d461 authored by Allen.Dou's avatar Allen.Dou Committed by GitHub
Browse files

[Misc] Keep only one implementation of the create_dummy_prompt function. (#4716)

parent 208b71bc
import time
from typing import Optional
import pytest import pytest
from vllm import SamplingParams from tests.core.utils import create_dummy_prompt
from vllm.lora.request import LoRARequest from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
from vllm.sequence import (SamplerOutput, Sequence, SequenceData, SequenceOutput)
SequenceGroup, SequenceGroupOutput, SequenceOutput)
def create_dummy_prompt(
request_id: str,
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> SequenceGroup:
if not block_size:
block_size = prompt_length
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup(
request_id, [prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
time.time(), lora_request)
return seq_group
@pytest.fixture @pytest.fixture
...@@ -102,7 +74,7 @@ def test_sequence_data_prefill(): ...@@ -102,7 +74,7 @@ def test_sequence_data_prefill():
def test_sequence_group_stage(): def test_sequence_group_stage():
seq_group = create_dummy_prompt("1", 12) _, seq_group = create_dummy_prompt("1", 12)
assert seq_group.is_prefill() is True assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(6) seq_group.update_num_computed_tokens(6)
assert seq_group.is_prefill() is True assert seq_group.is_prefill() is True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment