Commit 3dcb3e8b (unverified), authored by SangBin Cho, committed by GitHub
Browse files

[3/N] Refactor scheduler for chunked prefill scheduling (#3550)

parent c64cf386
This diff is collapsed.
import time import time
from typing import Tuple from typing import Optional, Tuple
from vllm import SamplingParams from vllm import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup from vllm.sequence import Logprob, Sequence, SequenceGroup
def create_dummy_prompt( def create_dummy_prompt(
request_id: str, request_id: str,
prompt_length: int, prompt_length: int,
block_size: int = None) -> Tuple[Sequence, SequenceGroup]: block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]:
if not block_size: if not block_size:
block_size = prompt_length block_size = prompt_length
...@@ -17,8 +22,10 @@ def create_dummy_prompt( ...@@ -17,8 +22,10 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length)) prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), seq_group = SequenceGroup(
time.time(), None) request_id, [prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
time.time(), lora_request)
return prompt, seq_group return prompt, seq_group
......
This diff is collapsed.
...@@ -728,7 +728,7 @@ class LLMEngine: ...@@ -728,7 +728,7 @@ class LLMEngine:
time_per_output_tokens = [] time_per_output_tokens = []
time_e2e_requests = [] time_e2e_requests = []
if scheduler_outputs is not None: if scheduler_outputs is not None:
prompt_run = scheduler_outputs.prompt_run prompt_run = scheduler_outputs.num_prefill_groups > 0
# Number of Tokens. # Number of Tokens.
if prompt_run: if prompt_run:
......
...@@ -6,7 +6,7 @@ import socket ...@@ -6,7 +6,7 @@ import socket
import subprocess import subprocess
import uuid import uuid
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict, defaultdict
from functools import lru_cache, partial from functools import lru_cache, partial
from platform import uname from platform import uname
from typing import (Any, Awaitable, Callable, Generic, Hashable, List, from typing import (Any, Awaitable, Callable, Generic, Hashable, List,
...@@ -450,3 +450,20 @@ def maybe_expand_dim(tensor: torch.Tensor, ...@@ -450,3 +450,20 @@ def maybe_expand_dim(tensor: torch.Tensor,
if tensor.ndim < target_dims: if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor return tensor
def merge_dicts(dict1: dict[Any, list[Any]],
                dict2: dict[Any, list[Any]]) -> dict[Any, list[Any]]:
    """Combine two ``key -> list of items`` mappings into a new dict.

    Every key found in either input appears in the result. When a key is
    present in both inputs, the items from ``dict1`` come first, followed
    by the items from ``dict2``. The result holds freshly built lists, so
    the input lists are not modified.
    """
    combined: dict[Any, list[Any]] = {}
    # Walk dict1 before dict2 so that dict1's items lead on shared keys.
    for source in (dict1, dict2):
        for key, items in source.items():
            combined.setdefault(key, []).extend(items)
    return combined
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment