"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "ca6a95ba259aaa4c89eccdd254fa7922e31eddc2"
Unverified Commit fcc2994b authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[CI] Nits for bad initialization of SeqGroup in testing (#4748)

parent 2e7796f2
...@@ -142,8 +142,10 @@ def test_append_slot_cow(): ...@@ -142,8 +142,10 @@ def test_append_slot_cow():
child = prompt.fork(new_seq_id=2) child = prompt.fork(new_seq_id=2)
# Allocate space for the sequence group. # Allocate space for the sequence group.
seq_group = SequenceGroup("1", [prompt, child], SamplingParams(), seq_group = SequenceGroup(request_id="1",
time.time(), time.perf_counter) seqs=[prompt, child],
arrival_time=time.time(),
sampling_params=SamplingParams())
block_manager.allocate(seq_group) block_manager.allocate(seq_group)
# Fork and append a new token id. We expect a COW to be scheduled. # Fork and append a new token id. We expect a COW to be scheduled.
...@@ -303,8 +305,11 @@ def test_sliding_window_multi_seq(): ...@@ -303,8 +305,11 @@ def test_sliding_window_multi_seq():
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
parent = Sequence(1, "one two three", [0, 1, 2], block_size) parent = Sequence(1, "one two three", [0, 1, 2], block_size)
seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(), seq_group = SequenceGroup(request_id="1",
None) seqs=[parent],
arrival_time=time.time(),
sampling_params=SamplingParams(),
lora_request=None)
block_manager.allocate(seq_group) block_manager.allocate(seq_group)
# assert the number of blocks allocated is correct # assert the number of blocks allocated is correct
......
...@@ -22,10 +22,13 @@ def create_dummy_prompt( ...@@ -22,10 +22,13 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length)) prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup( seq_group = SequenceGroup(request_id=request_id,
request_id, [prompt], seqs=[prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of), arrival_time=time.time(),
time.time(), lora_request) sampling_params=SamplingParams(
use_beam_search=use_beam_search,
best_of=best_of),
lora_request=lora_request)
return prompt, seq_group return prompt, seq_group
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment