"include/ck/utility/functional2.hpp" did not exist on "1b3c2e403585bf1884b195289b7e863d2924379d"
Commit 1132fae0 authored by Woosuk Kwon

Add Frontend

parent 46ce1356
cacheflow/master/frontend.py (new file)

from typing import List, Optional, Tuple

from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import Counter


class Frontend:

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        self.block_size = block_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def query(
        self,
        prompt: str,
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        if sampling_params is None:
            sampling_params = SamplingParams()
        token_ids: List[int] = self.tokenizer.encode(prompt)

        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)

        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs)
        self.inputs.append((seq_group, sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        inputs = self.inputs
        self.inputs = []
        return inputs

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        for seq in seq_group.seqs:
            token_ids = seq.get_token_ids()
            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            print(f'Seq {seq.seq_id}: {output}')
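For orientation, here is a minimal, hypothetical driver for the new Frontend class; it is not part of this commit. The model name reuses the 'facebook/opt-125m' checkpoint that the old scheduler code hard-coded, and the block size of 8 is an arbitrary placeholder.

# Hypothetical usage sketch; not part of this commit.
from cacheflow.master.frontend import Frontend

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# Buffer a request; a default SamplingParams() is created when none is given.
frontend.query('Hello, my name is')

# The scheduler side periodically drains the buffered requests.
for seq_group, sampling_params in frontend.get_inputs():
    print(seq_group.group_id, sampling_params.n)

# When a group finishes, the scheduler hands it back for decoding and printing:
# frontend.print_response(seq_group)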
cacheflow/master/scheduler.py

 from typing import Dict, List, Tuple
 
 from cacheflow.master.block_manager import BlockSpaceManager
+from cacheflow.master.frontend import Frontend
+from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
@@ -12,11 +14,13 @@ class Scheduler:
     def __init__(
         self,
+        frontend: Frontend,
         controllers: List,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
     ) -> None:
+        self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
@@ -33,16 +37,20 @@ class Scheduler:
         self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> max_num_steps.
-        self.max_num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> stop_token_ids.
-        self.stop_token_ids: Dict[int, List[int]] = {}
+        # Mapping: group_id -> sampling params.
+        self.sampling_params: Dict[int, SamplingParams] = {}
         # Swapped sequence groups (LIFO).
         self.swapped: List[SequenceGroup] = []
         # Pending sequence groups (FIFO).
         self.pending: List[SequenceGroup] = []
 
+    def _fetch_inputs(self) -> None:
+        inputs = self.frontend.get_inputs()
+        for seq_group, sampling_params in inputs:
+            self.pending.append(seq_group)
+            self.sampling_params[seq_group.group_id] = sampling_params
+
     def _free_seq(self, seq: Sequence) -> None:
         seq.status = SequenceStatus.FINISHED
         self.block_manager.free(seq)
@@ -145,6 +153,7 @@ class Scheduler:
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
             # FIXME(woosuk): Acquire a lock to protect pending.
+            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -205,7 +214,7 @@ class Scheduler:
         for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
-            stop_token_ids = self.stop_token_ids[group_id]
+            stop_token_ids = self.sampling_params[group_id].stop_token_ids
 
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
@@ -230,24 +239,22 @@ class Scheduler:
                     continue
 
                 # Check if the sequence has reached the maximum number of steps.
-                if self.num_steps[group_id] == self.max_num_steps[group_id]:
+                max_num_steps = self.sampling_params[group_id].max_num_steps
+                if self.num_steps[group_id] == max_num_steps:
                     self._free_seq(seq)
                     continue
 
         # Update the running sequences.
         running: List[SequenceGroup] = []
         for seq_group in self.running:
-            if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
-                del self.num_steps[seq_group.group_id]
-                del self.max_num_steps[seq_group.group_id]
-                del self.stop_token_ids[seq_group.group_id]
-                # TODO: Return the seq_group to the client.
-                from transformers import AutoTokenizer
-                tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-                for seq in seq_group.seqs:
-                    token_ids = seq.get_token_ids()
-                    output = tokenizer.decode(token_ids, skip_special_tokens=True)
-                    print(f'Seq {seq.seq_id}: {output}')
+            if seq_group.is_finished():
+                self._return(seq_group)
             else:
                 running.append(seq_group)
         self.running = running
 
+    def _return(self, seq_group: SequenceGroup) -> None:
+        group_id = seq_group.group_id
+        del self.num_steps[group_id]
+        del self.sampling_params[group_id]
+        self.frontend.print_response(seq_group)
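Taken together, the new flow might be wired up roughly as follows. This is a hypothetical sketch, not code from this commit: the empty controllers list and the block counts are placeholders, the scheduler module path is assumed to be cacheflow.master.scheduler, and how the scheduling loop is driven is outside this diff.

# Hypothetical wiring sketch; values below are placeholders.
from cacheflow.master.frontend import Frontend
from cacheflow.master.scheduler import Scheduler

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)
scheduler = Scheduler(
    frontend=frontend,
    controllers=[],        # device controllers, constructed elsewhere
    block_size=8,
    num_gpu_blocks=1024,   # assumed KV-cache capacities
    num_cpu_blocks=256,
)

frontend.query('Hello, my name is')
# Each scheduling pass calls _fetch_inputs() to pull pending sequence groups
# and their sampling params from the frontend; once a group is finished,
# _return() drops its bookkeeping and calls frontend.print_response().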