Unverified Commit 3dcb3e8b authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[3/N] Refactor scheduler for chunked prefill scheduling (#3550)

parent c64cf386
This diff is collapsed.
import time
from typing import Tuple
from typing import Optional, Tuple
from vllm import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup
def create_dummy_prompt(
request_id: str,
prompt_length: int,
block_size: int = None) -> Tuple[Sequence, SequenceGroup]:
request_id: str,
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]:
if not block_size:
block_size = prompt_length
......@@ -17,8 +22,10 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup(request_id, [prompt], SamplingParams(),
time.time(), None)
seq_group = SequenceGroup(
request_id, [prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
time.time(), lora_request)
return prompt, seq_group
......
This diff is collapsed.
......@@ -728,7 +728,7 @@ class LLMEngine:
time_per_output_tokens = []
time_e2e_requests = []
if scheduler_outputs is not None:
prompt_run = scheduler_outputs.prompt_run
prompt_run = scheduler_outputs.num_prefill_groups > 0
# Number of Tokens.
if prompt_run:
......
......@@ -6,7 +6,7 @@ import socket
import subprocess
import uuid
import warnings
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from functools import lru_cache, partial
from platform import uname
from typing import (Any, Awaitable, Callable, Generic, Hashable, List,
......@@ -450,3 +450,20 @@ def maybe_expand_dim(tensor: torch.Tensor,
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
def merge_dicts(dict1: dict[Any, list[Any]],
dict2: dict[Any, list[Any]]) -> dict[Any, list[Any]]:
"""Merge 2 dicts that have key -> List of items.
When a key conflicts, the values in dict1 is prioritized.
"""
merged_dict = defaultdict(list)
for key, value in dict1.items():
merged_dict[key].extend(value)
for key, value in dict2.items():
merged_dict[key].extend(value)
return dict(merged_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment