Unverified Commit c3442c1f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Refactor system architecture (#109)

parent 7297fa6f
"""A GPU worker class."""
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import torch
from cacheflow.model_executor import (get_model, get_cache_block_size,
InputMetadata, set_random_seed)
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
from cacheflow.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel,
initialize_all_reduce_launcher,
get_tensor_model_parallel_world_size)
initialize_model_parallel, initialize_all_reduce_launcher)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
SequenceOutputs)
......@@ -26,59 +25,46 @@ class Worker:
def __init__(
self,
model_name: str,
dtype: str,
seed: int,
distributed_init_method: str,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
rank: int,
world_size: int,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
max_num_batched_tokens: int,
max_num_sequences: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
distributed_init_method: str,
) -> None:
self.init_distributed_environment(distributed_init_method,
rank,
world_size,
tensor_parallel_size,
pipeline_parallel_size)
self.worker_id = rank
self.seed = seed
set_random_seed(self.seed)
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.rank = rank
self.distributed_init_method = distributed_init_method
# Initialize the distributed environment.
_init_distributed_environment(parallel_config, rank,
distributed_init_method)
# Initialize the model.
self.model, self.dtype = get_model(
model_name, dtype=dtype, cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.max_num_batched_tokens = max_num_batched_tokens
set_random_seed(self.model_config.seed)
self.model = get_model(model_config)
initialize_all_reduce_launcher(
self.max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
self.max_num_sequences = max_num_sequences
self.num_layers = self.model.config.num_hidden_layers
assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
self.head_size = self.model.config.hidden_size // (self.num_heads * tensor_model_parallel_world_size)
# We reset the seed after initializing the model to ensure that
# the random state is not affected by the model initialization.
set_random_seed(seed)
# Uninitialized cache engine. Will be initialized with
self.scheduler_config.max_num_batched_tokens,
self.model_config.get_hidden_size(),
self.model_config.dtype,
)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.block_size = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
@torch.inference_mode()
def get_num_available_blocks(
self, block_size: int, cpu_swap_space: int,
gpu_memory_utilization: float) -> Tuple[int, int]:
def profile_num_available_blocks(
self,
block_size: int,
gpu_memory_utilization: float,
cpu_swap_space: int,
) -> Tuple[int, int]:
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
......@@ -90,14 +76,15 @@ class Worker:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99,
top_k=self.model.config.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
seqs = []
for group_id in range(self.max_num_sequences):
seq_len = (self.max_num_batched_tokens // self.max_num_sequences +
(group_id < self.max_num_batched_tokens %
self.max_num_sequences))
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len)
seq = SequenceGroupMetadata(
group_id=group_id,
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
......@@ -105,13 +92,14 @@ class Worker:
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs)
input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
# Execute the model.
num_layers = self.model_config.get_num_layers(self.parallel_config)
self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=[(None, None)] * self.num_layers,
kv_caches=[(None, None)] * num_layers,
input_metadata=input_metadata,
cache_events=None,
)
......@@ -121,53 +109,27 @@ class Worker:
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated()
total_gpu_memory = get_gpu_memory()
cache_block_size = get_cache_block_size(block_size, self.num_heads,
self.head_size, self.num_layers,
self.dtype)
cache_block_size = CacheEngine.get_cache_block_size(
block_size, self.model_config, self.parallel_config)
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
- peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
torch.cuda.empty_cache()
# Reset the seed to ensure that the model output is not affected by
# the profiling.
set_random_seed(self.seed)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
num_cpu_blocks: int):
self.block_size = block_size
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.block_size = cache_config.block_size
self.cache_engine = CacheEngine(
worker_id=self.worker_id,
num_layers=self.num_layers,
num_heads=self.num_heads,
head_size=self.head_size,
block_size=self.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=self.dtype,
)
self.cache_config, self.model_config, self.parallel_config)
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
def init_distributed_environment(self,
distributed_init_method: str,
rank: int,
world_size: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1) -> None:
"""Initialize the distributed environment."""
torch.distributed.init_process_group(
backend='nccl',
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(tensor_parallel_size,
pipeline_parallel_size)
def prepare_inputs(
def _prepare_inputs(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
......@@ -284,7 +246,7 @@ class Worker:
return tokens_tensor, positions_tensor, input_metadata
@torch.inference_mode()
def execute_stage(
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
......@@ -316,7 +278,7 @@ class Worker:
return {}
# Prepare input tensors.
input_tokens, input_positions, input_metadata = self.prepare_inputs(
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seq_group_metadata_list)
# Execute the model.
......@@ -330,6 +292,24 @@ class Worker:
return output
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: str,
) -> None:
"""Initialize the distributed environment."""
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
return x + [0] * ((-len(x)) % multiple_of)
......
import argparse
import uuid
from cacheflow import (add_server_arguments, initialize_server_from_args,
SamplingParams)
def main(args: argparse.Namespace):
# Initialize the server.
server = initialize_server_from_args(args)
# Test the following prompts.
test_prompts = [
("A robot may not injure a human being", SamplingParams()),
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, use_beam_search=True, temperature=0.0)),
]
# Run the server.
while True:
# To test iteration-level scheduling, we add one request at each step.
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
request_id = str(uuid.uuid4().hex[:8])
server.add_request(request_id, prompt, sampling_params)
request_outputs = server.step()
for request_output in request_outputs:
if request_output.done:
print(request_output)
if not (server.has_unfinished_requests() or test_prompts):
break
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
parser = add_server_arguments(parser)
args = parser.parse_args()
main(args)
import argparse
from cacheflow.core.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
def main(args: argparse.Namespace):
server, frontend = init_local_server_and_frontend_with_arguments(args)
# Test the following inputs.
test_inputs = [
("A robot may not injure a human being", {}), # Use default parameters.
("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
]
while True:
if test_inputs:
text, sampling_params_dict = test_inputs.pop(0)
sampling_params = SamplingParams(**sampling_params_dict)
sampling_params = frontend.add_eos_token(sampling_params)
frontend.query(text, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
updated_seq_groups = server.step()
for seq_group in updated_seq_groups:
if seq_group.is_finished():
frontend.print_response(seq_group)
if not (server.has_unfinished_requests() or test_inputs):
break
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = add_server_arguments(parser)
args = parser.parse_args()
args = process_server_arguments(args)
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment