Commit fa16389a authored by Woosuk Kwon

Clean up the server script

parent 6aef2278
@@ -1,5 +1,6 @@
 import argparse
 from typing import List
 
+from cacheflow.master.frontend import Frontend
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.worker.controller import Controller
@@ -8,15 +9,15 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='block size')
-parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks')
+parser.add_argument('--block-size', type=int, default=8, help='token block size')
+# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
+parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+args = parser.parse_args()
 
 
 def main():
-    args = parser.parse_args()
-
-    # Create controllers.
+    # Create a controller for each node.
     controllers: List[Controller] = []
     for i in range(args.num_nodes):
         controller = Controller(
@@ -26,12 +27,18 @@ def main():
             block_size=args.block_size,
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
+            dtype='float',
         )
         controllers.append(controller)
 
+    # Create a frontend.
+    frontend = Frontend(
+        model_name=args.model,
+        block_size=args.block_size,
+    )
+
     # Create a scheduler.
     scheduler = Scheduler(
+        frontend=frontend,
         controllers=controllers,
         block_size=args.block_size,
         num_gpu_blocks=args.num_gpu_blocks,
@@ -42,65 +49,19 @@ def main():
         controllers[i].set_next(controllers[i + 1])
     controllers[-1].set_next(scheduler)
 
-    # seq_groups, max_num_steps, stop_token_ids = generate_inputs(1000, args.block_size)
-    seq_groups, max_num_steps, stop_token_ids = test_inputs(args.block_size)
-    scheduler.pending.extend(seq_groups)
-    scheduler.max_num_steps.update(max_num_steps)
-    scheduler.stop_token_ids.update(stop_token_ids)
+    test_inputs = [
+        'Ion Stoica is a',
+        'UC Berkeley is',
+        'The future of cloud computing is',
+    ]
+    for prompt in test_inputs:
+        frontend.query(prompt)
 
-    while scheduler.pending or scheduler.running:
-        scheduler.prepare()
+    # FIXME
+    while True:
         scheduler.step()
-
-
-def test_inputs(block_size):
-    from cacheflow.sequence import Sequence
-    from cacheflow.sequence import SequenceGroup
-    from cacheflow.utils import Counter
-    from transformers import AutoTokenizer
-
-    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-    prompt = "Hello, I'm am conscious and"
-    prompt_tokens = tokenizer.encode(prompt)
-
-    seq = Sequence(0, prompt_tokens, block_size=block_size)
-    seq_group = SequenceGroup(0, [seq])
-    seq_groups = [seq_group]
-    max_num_steps = {0: 8}
-    stop_token_ids = {0: []}
-    return seq_groups, max_num_steps, stop_token_ids
-
-
-def generate_inputs(num_inputs, block_size):
-    import random
-    random.seed(0)
-
-    from cacheflow.sequence import Sequence
-    from cacheflow.sequence import SequenceGroup
-    from cacheflow.utils import Counter
-
-    seq_group_counter = Counter()
-    seq_counter = Counter()
-    max_num_steps = {}
-    stop_token_ids = {}
-    seq_groups = []
-    for _ in range(num_inputs):
-        seq_group_id = next(seq_group_counter)
-        prompt_len = random.randint(16, 128)
-        max_num_steps[seq_group_id] = random.randint(32, 1024)
-        stop_token_ids[seq_group_id] = []
-        seqs = []
-        for _ in range(2):
-            seq_id = next(seq_counter)
-            seq = Sequence(seq_id, [0] * prompt_len, block_size=block_size)
-            seqs.append(seq)
-        seq_group = SequenceGroup(seq_group_id, seqs)
-        seq_groups.append(seq_group)
-    return seq_groups, max_num_steps, stop_token_ids
+        if not scheduler.pending and not scheduler.running:
+            break
 
 
 if __name__ == '__main__':
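
A note on the reworded help strings: the commit clarifies that blocks are measured in tokens and that the block counts are per GPU. From the defaults alone, capacity works out as below; this is flag arithmetic only, since the bytes behind each token slot depend on model width and dtype, which is what the new TODO about an analytical model points at.

    block_size = 8          # tokens per cache block (--block-size default)
    num_gpu_blocks = 1024   # blocks per GPU (--num-gpu-blocks default)
    num_cpu_blocks = 256    # blocks per GPU (--num-cpu-blocks default)

    print(block_size * num_gpu_blocks)  # 8192 token slots of GPU cache per GPU
    print(block_size * num_cpu_blocks)  # 2048 token slots of CPU swap per GPU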
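
For orientation, here is a minimal sketch of the control flow the cleaned-up main() sets up: prompts enter through the frontend, the controllers are chained back to the scheduler via set_next, and the driver loop steps the scheduler until nothing is pending or running. The classes below are simplified stand-ins invented for this sketch, not the real Frontend/Scheduler/Controller from the cacheflow package.

    # Stand-in classes, not the cacheflow implementation; they only mirror the
    # shape of the script above: query() enqueues work, step() makes progress.
    class StubFrontend:
        def __init__(self):
            self.waiting = []

        def query(self, prompt):
            # The real Frontend also tokenizes (it takes model_name and block_size).
            self.waiting.append(prompt)


    class StubScheduler:
        def __init__(self, frontend):
            self.frontend = frontend
            self.pending = []
            self.running = []

        def step(self):
            # Admit newly queued prompts, then "run" one iteration.
            self.pending.extend(self.frontend.waiting)
            self.frontend.waiting.clear()
            if self.pending:
                self.running.append(self.pending.pop(0))
            elif self.running:
                print('finished:', self.running.pop(0))


    frontend = StubFrontend()
    scheduler = StubScheduler(frontend)
    for prompt in ['Ion Stoica is a', 'UC Berkeley is', 'The future of cloud computing is']:
        frontend.query(prompt)

    # Same loop shape as the new script; the FIXME there flags that a bare
    # while True busy-waits once steps become asynchronous.
    while True:
        scheduler.step()
        if not scheduler.pending and not scheduler.running:
            break

Creating the frontend before the scheduler, and passing the new frontend= argument to Scheduler, is what lets the scheduler pull queued work in this style; the per-step behavior above is invented for illustration only.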