"vscode:/vscode.git/clone" did not exist on "632d598c77af616278bc0f2144a14958678dcbae"
Commit 1132fae0 authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

Add Frontend

parent 46ce1356
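This change adds a Frontend class (cacheflow/master/frontend.py) that wraps a Hugging Face tokenizer, expands each prompt into a SequenceGroup holding sampling_params.n sequences, and buffers the groups together with their SamplingParams. The Scheduler now takes the frontend in its constructor, drains pending queries via a new _fetch_inputs() method, reads stop_token_ids and max_num_steps from the stored SamplingParams instead of keeping separate per-group dicts, and returns finished groups to the frontend through a new _return() method, which decodes and prints the output.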
cacheflow/master/frontend.py (new file)

from typing import List, Optional, Tuple

from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import Counter


class Frontend:

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        self.block_size = block_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def query(
        self,
        prompt: str,
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        if sampling_params is None:
            sampling_params = SamplingParams()
        token_ids: List[int] = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs)
        self.inputs.append((seq_group, sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        inputs = self.inputs
        self.inputs = []
        return inputs

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        for seq in seq_group.seqs:
            token_ids = seq.get_token_ids()
            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            print(f'Seq {seq.seq_id}: {output}')
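A minimal usage sketch of the new class, for orientation. This driver code is not part of the commit; the model name and block size are arbitrary examples, and SamplingParams is assumed to be constructible with no arguments, as the default path in query() implies.

# Hypothetical driver code, not part of this commit.
from cacheflow.master.frontend import Frontend

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# Tokenizes the prompt and buffers one SequenceGroup, using default
# sampling parameters (query() falls back to SamplingParams()).
frontend.query('Hello, my name is')

# Drains the buffer; the scheduler consumes these (group, params) pairs.
for seq_group, sampling_params in frontend.get_inputs():
    print(seq_group.group_id, len(seq_group.seqs))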
cacheflow/master/scheduler.py

 from typing import Dict, List, Tuple

 from cacheflow.master.block_manager import BlockSpaceManager
+from cacheflow.master.frontend import Frontend
+from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
@@ -12,11 +14,13 @@ class Scheduler:
     def __init__(
         self,
+        frontend: Frontend,
         controllers: List,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
     ) -> None:
+        self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
@@ -33,16 +37,20 @@ class Scheduler:
         self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> max_num_steps.
-        self.max_num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> stop_token_ids.
-        self.stop_token_ids: Dict[int, List[int]] = {}
+        # Mapping: group_id -> sampling params.
+        self.sampling_params: Dict[int, SamplingParams] = {}
         # Swapped sequence groups (LIFO).
         self.swapped: List[SequenceGroup] = []
         # Pending sequence groups (FIFO).
         self.pending: List[SequenceGroup] = []

+    def _fetch_inputs(self) -> None:
+        inputs = self.frontend.get_inputs()
+        for seq_group, sampling_params in inputs:
+            self.pending.append(seq_group)
+            self.sampling_params[seq_group.group_id] = sampling_params
+
     def _free_seq(self, seq: Sequence) -> None:
         seq.status = SequenceStatus.FINISHED
         self.block_manager.free(seq)
@@ -145,6 +153,7 @@ class Scheduler:
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
             # FIXME(woosuk): Acquire a lock to protect pending.
+            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -205,7 +214,7 @@ class Scheduler:
         for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
-            stop_token_ids = self.stop_token_ids[group_id]
+            stop_token_ids = self.sampling_params[group_id].stop_token_ids
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
@@ -230,24 +239,22 @@ class Scheduler:
                     continue

                # Check if the sequence has reached the maximum number of steps.
-                if self.num_steps[group_id] == self.max_num_steps[group_id]:
+                max_num_steps = self.sampling_params[group_id].max_num_steps
+                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)
                    continue

         # Update the running sequences.
         running: List[SequenceGroup] = []
         for seq_group in self.running:
-            if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
-                del self.num_steps[seq_group.group_id]
-                del self.max_num_steps[seq_group.group_id]
-                del self.stop_token_ids[seq_group.group_id]
-                # TODO: Return the seq_group to the client.
-                from transformers import AutoTokenizer
-                tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-                for seq in seq_group.seqs:
-                    token_ids = seq.get_token_ids()
-                    output = tokenizer.decode(token_ids, skip_special_tokens=True)
-                    print(f'Seq {seq.seq_id}: {output}')
+            if seq_group.is_finished():
+                self._return(seq_group)
             else:
                 running.append(seq_group)
         self.running = running
+
+    def _return(self, seq_group: SequenceGroup) -> None:
+        group_id = seq_group.group_id
+        del self.num_steps[group_id]
+        del self.sampling_params[group_id]
+        self.frontend.print_response(seq_group)
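Taken together, the new data flow is: Frontend.query() buffers work, Scheduler._fetch_inputs() moves it into the pending queue at the start of scheduling, and Scheduler._return() hands finished groups back to the frontend for printing. A hedged wiring sketch follows; the controllers list, the block counts, and the scheduler's step loop are placeholder assumptions outside this diff.

# Hypothetical wiring, not part of this commit.
from cacheflow.master.frontend import Frontend
from cacheflow.master.scheduler import Scheduler  # assumed module path

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)
scheduler = Scheduler(
    frontend=frontend,
    controllers=[],       # placeholder; controller setup is not shown here
    block_size=8,
    num_gpu_blocks=1024,  # placeholder capacity
    num_cpu_blocks=256,   # placeholder capacity
)

frontend.query('The capital of France is')
# Each scheduling step calls _fetch_inputs() to pull the query into
# scheduler.pending; once every sequence in a group is FINISHED,
# _return() drops the group's bookkeeping and calls
# frontend.print_response(seq_group).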