Unverified commit c3442c1f, authored by Woosuk Kwon and committed by GitHub

Refactor system architecture (#109)

parent 7297fa6f
...@@ -10,13 +10,17 @@ pip install -e . # This may take several minutes.

## Test simple server

```bash
# Single-GPU inference.
python examples/simple_server.py # --model <your_model>

# Multi-GPU inference (e.g., 2 GPUs).
ray start --head
python examples/simple_server.py -tp 2 # --model <your_model>
```

The detailed arguments for `simple_server.py` can be found by:
```bash
python examples/simple_server.py --help
```

## FastAPI server

...@@ -24,12 +28,12 @@ python simple_server.py --help

To start the server:
```bash
ray start --head
python -m cacheflow.entrypoints.fastapi_server # --model <your_model>
```

To test the server:
```bash
python test_cli_client.py
```

## Gradio web server

...@@ -55,7 +59,6 @@ Since LLaMA weight is not fully public, we cannot directly download the LLaMA we
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path/llama-7b
```
Please make sure that `llama` is included in the output directory name.
2. For all the commands above, specify the model with `--model /output/path/llama-7b` to load the model. For example:
```bash
python simple_server.py --model /output/path/llama-7b
...
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (
add_server_arguments,
create_server_configs_from_args,
initialize_server_from_args,
)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
__all__ = [
"RequestOutput",
"SamplingParams",
"LLMServer",
"add_server_arguments",
"create_server_configs_from_args",
"initialize_server_from_args",
"initialize_cluster",
]
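The package root now exposes the serving API directly. Below is a minimal sketch of how these exports are meant to fit together, pieced together from the call sites elsewhere in this diff; the exact `add_request` signature, the `SamplingParams` keyword arguments, and the shape of the step loop are assumptions, not code from the commit.

```python
import argparse
import time

from cacheflow import (SamplingParams, add_server_arguments,
                       initialize_server_from_args)

# Build a parser with the shared CacheFlow server arguments and create the
# server from the parsed args (both helpers are exported above).
parser = argparse.ArgumentParser()
parser = add_server_arguments(parser)
args = parser.parse_args()
server = initialize_server_from_args(args)

# Submit one request and step the server until it reports the request as done.
# NOTE: add_request/step are only shown being called remotely in this diff,
# and the SamplingParams fields used here (n, max_tokens) are assumptions
# based on how the FastAPI frontend and scheduler use them.
sampling_params = SamplingParams(n=1, max_tokens=32)
server.add_request("request-0", "Hello, my name is", sampling_params,
                   arrival_time=time.time())

finished = False
while not finished:
    for request_output in server.step():
        print(request_output)
        finished = finished or request_output.done
```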
from typing import Optional
import torch
from transformers import AutoConfig, PretrainedConfig
class ModelConfig:
def __init__(
self,
model: str,
download_dir: Optional[str],
use_np_weights: bool,
use_dummy_weights: bool,
dtype: str,
seed: int,
) -> None:
self.model = model
self.download_dir = download_dir
self.use_np_weights = use_np_weights
self.use_dummy_weights = use_dummy_weights
self.seed = seed
self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_config.num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
f"Total number of attention heads ({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")
total_num_hidden_layers = self.hf_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) "
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
def get_head_size(self) -> int:
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
class CacheConfig:
def __init__(
self,
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space = swap_space
# Will be set after profiling.
self.num_gpu_blocks = None
self.num_cpu_blocks = None
class ParallelConfig:
def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
use_ray: bool,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.use_ray = use_ray
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.use_ray = True
self._verify_args()
def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
class SchedulerConfig:
def __init__(
self,
max_num_batched_tokens: int,
max_num_seqs: int,
) -> None:
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: str,
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
dtype = dtype.lower()
if dtype == "default":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
pass
else:
# Casting between float16 and bfloat16 is not allowed.
raise ValueError(
f"Cannot use {torch_dtype} for {config_dtype} model.")
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
return torch_dtype
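For reference, a small example of how the four new config objects compose. The constructor values mirror the argparse defaults in `cacheflow/server/arg_utils.py` further down in this diff and are otherwise illustrative.

```python
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)

# Values mirror the argparse defaults in cacheflow/server/arg_utils.py.
model_config = ModelConfig(model="facebook/opt-125m", download_dir=None,
                           use_np_weights=False, use_dummy_weights=False,
                           dtype="default", seed=0)
parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=1, use_ray=False)
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.95,
                           swap_space=4 * (1 << 30))  # 4 GiB, in bytes
scheduler_config = SchedulerConfig(max_num_batched_tokens=2560,
                                   max_num_seqs=256)

# Raises ValueError if the model's attention heads or hidden layers do not
# divide evenly across the requested tensor/pipeline parallel sizes.
model_config.verify_with_parallel_config(parallel_config)
```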
...@@ -2,10 +2,10 @@ import enum
import time
from typing import Dict, List, Optional, Tuple

from cacheflow.config import CacheConfig, SchedulerConfig
from cacheflow.core.block_manager import BlockSpaceManager
from cacheflow.core.policy import PolicyFactory
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
                                SequenceGroupMetadata, SequenceOutputs,
                                SequenceStatus)
...@@ -28,43 +28,53 @@ class PreemptionMode(enum.Enum):
    RECOMPUTE = enum.auto()
class SchedulerOutputs:
def __init__(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
self.blocks_to_copy = blocks_to_copy
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
def is_empty(self) -> bool:
return (not self.blocks_to_swap_in
and not self.blocks_to_swap_out
and not self.blocks_to_copy)
class Scheduler: class Scheduler:
def __init__( def __init__(
self, self,
controllers: List, scheduler_config: SchedulerConfig,
block_size: int, cache_config: CacheConfig,
num_gpu_blocks: int,
num_cpu_blocks: int,
max_num_batched_tokens: int,
max_num_sequences: int,
log_stats: bool, log_stats: bool,
) -> None: ) -> None:
self.controllers = controllers self.scheduler_config = scheduler_config
self.block_size = block_size self.cache_config = cache_config
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_sequences = max_num_sequences
self.log_stats = log_stats self.log_stats = log_stats
# Instantiate the scheduling policy. # Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name='fcfs') self.policy = PolicyFactory.get_policy(policy_name='fcfs')
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManager( self.block_manager = BlockSpaceManager(
block_size=block_size, block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
) )
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = [] self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state. # Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = [] self.running: List[SequenceGroup] = []
# Mapping: group_id -> num_steps. # Mapping: request_id -> num_steps.
self.num_steps: Dict[int, int] = {} self.num_steps: Dict[str, int] = {}
# Mapping: group_id -> sampling params.
self.sampling_params: Dict[int, SamplingParams] = {}
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = [] self.swapped: List[SequenceGroup] = []
...@@ -72,18 +82,15 @@ class Scheduler: ...@@ -72,18 +82,15 @@ class Scheduler:
# List[timestamp, num_tokens] # List[timestamp, num_tokens]
self.num_input_tokens: List[Tuple[float, int]] = [] self.num_input_tokens: List[Tuple[float, int]] = []
def add_sequence_groups( def add_seq_group(self, seq_group: SequenceGroup) -> None:
self,
seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
) -> None:
# Add sequence groups to the waiting queue. # Add sequence groups to the waiting queue.
for seq_group, sampling_params in seq_groups: assert seq_group.request_id not in self.num_steps
self.waiting.append(seq_group) self.waiting.append(seq_group)
self.sampling_params[seq_group.group_id] = sampling_params
def _schedule( def has_unfinished_seqs(self) -> bool:
self, return self.waiting or self.running or self.swapped
) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {}
...@@ -136,8 +143,9 @@ class Scheduler: ...@@ -136,8 +143,9 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not # The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences. # exceed the maximum number of sequences.
num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
if len(self.running) + num_seqs > self.max_num_sequences: num_curr_seqs = len(self.running)
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
break break
seq_group = self.swapped.pop(0) seq_group = self.swapped.pop(0)
...@@ -151,7 +159,7 @@ class Scheduler: ...@@ -151,7 +159,7 @@ class Scheduler:
) )
# Join waiting sequences if possible. # Join waiting sequences if possible.
prompt_group_ids: List[int] = [] prompt_group_ids: List[str] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state. # prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by # This is because we want to bound the amount of CPU memory taken by
...@@ -172,25 +180,31 @@ class Scheduler: ...@@ -172,25 +180,31 @@ class Scheduler:
# If the number of batched tokens exceeds the limit, stop. # If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len() num_prompt_tokens = seq_group.seqs[0].get_len()
if (num_batched_tokens + num_prompt_tokens if (num_batched_tokens + num_prompt_tokens
> self.max_num_batched_tokens): > self.scheduler_config.max_num_batched_tokens):
break break
# The total number of sequences in the RUNNING state should not # The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences. # exceed the maximum number of sequences.
num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING) num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
if len(self.running) + num_seqs > self.max_num_sequences: num_curr_seqs = len(self.running)
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
break break
seq_group = self.waiting.pop(0) seq_group = self.waiting.pop(0)
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.group_id) prompt_group_ids.append(seq_group.request_id)
scheduler_outputs = SchedulerOutputs(
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
if not self.log_stats: if not self.log_stats:
return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, return scheduler_outputs, prompt_group_ids
prompt_group_ids)
# TODO(woosuk): Move the below code to server.
now = time.time() now = time.time()
if num_batched_tokens > 0: if num_batched_tokens > 0:
self.num_input_tokens.append((now, num_batched_tokens)) self.num_input_tokens.append((now, num_batched_tokens))
...@@ -208,13 +222,16 @@ class Scheduler: ...@@ -208,13 +222,16 @@ class Scheduler:
else: else:
avg_throughput = 0.0 avg_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks() num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / self.num_gpu_blocks gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
if self.num_cpu_blocks > 0:
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks() num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / self.num_cpu_blocks cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else: else:
cpu_cache_usage = 0.0 cpu_cache_usage = 0.0
...@@ -225,27 +242,18 @@ class Scheduler: ...@@ -225,27 +242,18 @@ class Scheduler:
f"Pending: {len(self.waiting)} reqs, " f"Pending: {len(self.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
return scheduler_outputs, prompt_group_ids
return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
prompt_group_ids)
def step(self) -> List[SequenceGroup]:
# Schedule sequence groups. # Schedule sequence groups.
# This function call changes the internal states of the scheduler # This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting. # such as self.running, self.swapped, and self.waiting.
scheduler_output = self._schedule() scheduler_outputs, prompt_group_ids = self._schedule()
blocks_to_swap_in = scheduler_output[0]
blocks_to_swap_out = scheduler_output[1]
blocks_to_copy = scheduler_output[2]
prompt_group_ids = scheduler_output[3]
# Create input data structures. # Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy()
for seq_group in self.running: for seq_group in self.running:
group_id = seq_group.group_id is_prompt = seq_group.request_id in prompt_group_ids
is_prompt = group_id in prompt_group_ids
seq_data: Dict[int, List[SequenceData]] = {} seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {} block_tables: Dict[int, List[int]] = {}
...@@ -255,36 +263,24 @@ class Scheduler: ...@@ -255,36 +263,24 @@ class Scheduler:
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
group_id=group_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,
seq_data=seq_data, seq_data=seq_data,
sampling_params=self.sampling_params[group_id], sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
# Execute the first stage of the pipeline. def update(
if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.controllers[0].execute_stage(
seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return updated_seq_groups
def post_step(
self, self,
seq_outputs: Dict[int, SequenceOutputs], seq_outputs: Dict[int, SequenceOutputs],
) -> None: ) -> List[SequenceGroup]:
# Update the running sequences and free blocks. # Update the running sequences and free blocks.
for seq_group in self.running: for seq_group in self.running:
group_id = seq_group.group_id request_id = seq_group.request_id
self.num_steps[group_id] += 1 self.num_steps[request_id] += 1
stop_token_ids = self.sampling_params[group_id].stop_token_ids stop_token_ids = seq_group.sampling_params.stop_token_ids
# Process beam search results before processing the next tokens. # Process beam search results before processing the next tokens.
for seq in seq_group.seqs: for seq in seq_group.seqs:
...@@ -316,12 +312,13 @@ class Scheduler: ...@@ -316,12 +312,13 @@ class Scheduler:
continue continue
# Check if the sequence has reached the maximum number of steps. # Check if the sequence has reached the maximum number of steps.
max_num_steps = self.sampling_params[group_id].max_tokens max_num_steps = seq_group.sampling_params.max_tokens
if self.num_steps[group_id] == max_num_steps: if self.num_steps[request_id] == max_num_steps:
self._free_seq(seq) self._free_seq(seq)
continue continue
# Update the running sequences. # Update the running sequences.
updated = self.running.copy()
running: List[SequenceGroup] = [] running: List[SequenceGroup] = []
for seq_group in self.running: for seq_group in self.running:
if seq_group.is_finished(): if seq_group.is_finished():
...@@ -329,13 +326,14 @@ class Scheduler: ...@@ -329,13 +326,14 @@ class Scheduler:
else: else:
running.append(seq_group) running.append(seq_group)
self.running = running self.running = running
return updated
def _allocate(self, seq_group: SequenceGroup) -> None: def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group) self.block_manager.allocate(seq_group)
for seq in seq_group.seqs: for seq in seq_group.seqs:
seq.status = SequenceStatus.RUNNING seq.status = SequenceStatus.RUNNING
if seq_group.group_id not in self.num_steps: if seq_group.request_id not in self.num_steps:
self.num_steps[seq_group.group_id] = 0 self.num_steps[seq_group.request_id] = 0
def _append_slot( def _append_slot(
self, self,
...@@ -410,9 +408,7 @@ class Scheduler: ...@@ -410,9 +408,7 @@ class Scheduler:
self.block_manager.free(seq) self.block_manager.free(seq)
def _free_seq_group(self, seq_group: SequenceGroup) -> None: def _free_seq_group(self, seq_group: SequenceGroup) -> None:
group_id = seq_group.group_id del self.num_steps[seq_group.request_id]
del self.num_steps[group_id]
del self.sampling_params[group_id]
def _swap_in( def _swap_in(
self, self,
......
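With this refactor the scheduler no longer drives the workers itself: `schedule()` returns the per-group metadata plus a `SchedulerOutputs`, and `update()` folds the sampled outputs back in and returns the groups that ran. The sketch below shows the intended driving loop; `LLMServer.step` itself is not visible in this diff, and the `execute_model` call is a placeholder for however the server dispatches to its workers (name and signature are assumptions).

```python
from typing import Any, Dict, List

from cacheflow.core.scheduler import Scheduler
from cacheflow.sequence import SequenceGroup, SequenceOutputs


def run_one_step(scheduler: Scheduler, worker: Any) -> List[SequenceGroup]:
    """Hypothetical single engine step built from the Scheduler API above."""
    # 1. Decide what to run and which cache blocks to swap in/out or copy.
    seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

    # 2. Execute the model on the worker(s). `execute_model` is assumed here;
    #    in the new design the server, not the scheduler, owns this call.
    seq_outputs: Dict[int, SequenceOutputs] = worker.execute_model(
        seq_group_metadata_list,
        blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
        blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
        blocks_to_copy=scheduler_outputs.blocks_to_copy,
    )

    # 3. Apply the sampled tokens, free finished sequences, and return the
    #    sequence groups that were running during this step.
    return scheduler.update(seq_outputs)
```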
import argparse
import random
from typing import List, Optional, Tuple
try:
import ray
except ImportError:
ray = None
import numpy as np
import torch
from cacheflow.core.scheduler import Scheduler
from cacheflow.frontend.simple_frontend import SimpleFrontend
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup
from cacheflow.worker.controller import Controller, DeviceID
logger = init_logger(__name__)
class Server:
def __init__(
self,
model: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
dtype: str,
seed: int,
swap_space: int,
gpu_memory_utilization: float,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
use_ray: bool,
log_stats: bool,
):
logger.info(
"Initializing a server with config: "
f"model={model!r}, "
f"dtype={dtype}, "
f"use_dummy_weights={use_dummy_weights}, "
f"cache_dir={cache_dir!r}, "
f"use_np_cache={use_np_cache}, "
f"tensor_parallel_size={tensor_parallel_size}, "
f"seed={seed})"
)
self.num_nodes = num_nodes
self.num_devices_per_node = num_devices_per_node
self.world_size = pipeline_parallel_size * tensor_parallel_size
if not use_ray:
assert self.world_size == 1, (
"Only support single GPU without Ray.")
# Create a controller for each pipeline stage.
self.controllers: List[Controller] = []
for i in range(pipeline_parallel_size):
controller = Controller(
stage_id=i,
stage_devices=all_stage_devices[i],
world_size=self.world_size,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
distributed_init_method=distributed_init_method,
model_name=model,
dtype=dtype,
seed=seed,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
use_ray=use_ray,
)
self.controllers.append(controller)
# Initialize cache engine.
all_worker_num_available_blocks = []
for controller in self.controllers:
all_worker_num_available_blocks.extend(
controller.get_num_available_blocks(
block_size, swap_space, gpu_memory_utilization)
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
self.num_gpu_blocks = np.min([b[0] for b in all_worker_num_available_blocks])
self.num_cpu_blocks = np.min([b[1] for b in all_worker_num_available_blocks])
logger.info(f'# GPU blocks: {self.num_gpu_blocks}, '
f'# CPU blocks: {self.num_cpu_blocks}')
for controller in self.controllers:
controller.init_cache_engine(block_size, self.num_gpu_blocks,
self.num_cpu_blocks)
# Create a scheduler.
self.scheduler = Scheduler(
controllers=self.controllers,
block_size=block_size,
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
log_stats=log_stats,
)
# Connect the controllers.
for i in range(len(self.controllers) - 1):
self.controllers[i].set_next(self.controllers[i + 1])
self.controllers[-1].set_next(self.scheduler)
def add_sequence_groups(
self,
sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
):
self.scheduler.add_sequence_groups(sequence_groups)
def step(self):
return self.scheduler.step()
def has_unfinished_requests(self):
return (self.scheduler.waiting or self.scheduler.running or
self.scheduler.swapped)
def initialize_cluster(
use_ray: bool = False,
address: Optional[str] = None,
pipeline_parallel_size: int = 1,
tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
# Initialize cluster locally.
if not use_ray:
assert pipeline_parallel_size * tensor_parallel_size == 1, (
"Only support single GPU without Ray.")
num_nodes = 1
num_devices_per_node = torch.cuda.device_count()
port = random.randint(10000, 20000)
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
all_stage_devices = [[(0, None, 0)]]
return (num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices)
assert ray is not None, (
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
# Assume we have a uniform cluster that each node has the same number of
# GPUs for now.
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
if (not node['Alive']) or node['Resources']['GPU'] <= 0:
continue
if num_devices_per_node is None:
num_devices_per_node = node['Resources']['GPU']
else:
assert num_devices_per_node == node['Resources']['GPU'], (
"The number of GPUs per node is not uniform.")
for key in node['Resources']:
if key.startswith('node:'):
valid_node_resources.append(key)
num_nodes = len(valid_node_resources)
assert (pipeline_parallel_size * tensor_parallel_size
<= num_nodes * num_devices_per_node), (
"The number of required GPUs exceeds the total number of "
"available GPUs.")
if tensor_parallel_size >= num_devices_per_node:
assert tensor_parallel_size % num_devices_per_node == 0, (
"The number of tensor parallelism is not divisible by the "
"number of GPUs per node.")
else:
assert num_devices_per_node % tensor_parallel_size == 0, (
"The number of GPUs per node is not divisible by the number "
"of tensor parallelism.")
# Assign GPUs to pipeline stages.
rank = 0
current_node_id = 0
current_device_id = 0
distributed_init_method = None
all_stage_devices = []
for i in range(pipeline_parallel_size):
stage_devices = []
for j in range(tensor_parallel_size):
node_resource = valid_node_resources[current_node_id]
stage_devices.append((rank, node_resource, current_device_id))
if distributed_init_method is None:
ip = node_resource.split("node:")[-1]
port = random.randint(10000, 20000)
distributed_init_method = f"tcp://{ip}:{port}"
rank += 1
current_device_id += 1
if current_device_id >= num_devices_per_node:
current_node_id += 1
current_device_id = 0
all_stage_devices.append(stage_devices)
return (num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices)
_GiB = 1 << 30
def add_server_arguments(parser: argparse.ArgumentParser):
"""Shared arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--cache-dir', type=str, default=None,
help='cache dir to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# TODO(woosuk): Support FP32 for debugging.
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'))
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--log-stats', action='store_true', help='log system statistics')
return parser
def process_server_arguments(args: argparse.Namespace):
"""Post process the parsed arguments."""
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
args.use_ray = True
args.swap_space = args.swap_space * _GiB
args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens)
return args
def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
cache_dir=args.cache_dir,
use_dummy_weights=args.use_dummy_weights,
use_np_cache=args.use_np_cache,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
use_ray=args.use_ray,
log_stats=args.log_stats,
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
return server, frontend
...@@ -2,115 +2,66 @@ import argparse ...@@ -2,115 +2,66 @@ import argparse
import asyncio import asyncio
import json import json
import time import time
from typing import List, Dict, Optional from typing import Any, Dict
import uuid
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import ray import ray
import uvicorn import uvicorn
from cacheflow.core.server import (Server, add_server_arguments, from cacheflow.outputs import RequestOutput
process_server_arguments,
initialize_cluster)
from cacheflow.frontend.utils import get_tokenizer
from cacheflow.sampling_params import SamplingParams from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup from cacheflow.server.arg_utils import (
from cacheflow.utils import Counter add_server_arguments, create_server_configs_from_args)
from cacheflow.worker.controller import DeviceID from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI() app = FastAPI()
class FastAPIServer: class FastAPIServer:
def __init__(
self, def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
model: str,
cache_dir: Optional[str],
use_np_cache: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
dtype: str,
seed: int,
swap_space: int,
gpu_memory_utilization: float,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
server_use_ray: bool,
log_stats: bool,
):
self.block_size = block_size
self.tokenizer = get_tokenizer(model)
self.seq_group_counter = Counter()
self.seq_counter = Counter()
if server_use_ray: if server_use_ray:
remote_server_class = ray.remote(num_cpus=0)(Server) remote_server_class = ray.remote(num_cpus=0)(LLMServer)
else: else:
remote_server_class = ray.remote(num_gpus=1)(Server) remote_server_class = ray.remote(num_gpus=1)(LLMServer)
self.server = remote_server_class.remote( self.server = remote_server_class.remote(*args, **kwargs)
model=model,
cache_dir=cache_dir, # Request id -> request output.
use_dummy_weights=False, self.request_outputs: Dict[str, RequestOutput] = {}
use_np_cache=use_np_cache, # Request id -> event to notify that there is new output.
pipeline_parallel_size=pipeline_parallel_size, self.request_events: Dict[str, asyncio.Event] = {}
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
dtype=dtype,
seed=seed,
swap_space=swap_space,
gpu_memory_utilization=gpu_memory_utilization,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
use_ray=server_use_ray,
log_stats=log_stats,
)
self.running_seq_groups: Dict[int, SequenceGroup] = {}
self.sequence_group_events: Dict[int, asyncio.Event] = {}
self.is_server_running = False self.is_server_running = False
async def server_step(self): async def server_step(self):
self.is_server_running = True self.is_server_running = True
updated_seq_groups = await self.server.step.remote() request_outputs = await self.server.step.remote()
self.is_server_running = False self.is_server_running = False
# Notify the waiting coroutines that there are new outputs ready. # Notify the waiting coroutines that there are new outputs ready.
for seq_group in updated_seq_groups: for request_output in request_outputs:
group_id = seq_group.group_id request_id = request_output.request_id
self.running_seq_groups[group_id] = seq_group self.request_outputs[request_id] = request_output
self.sequence_group_events[group_id].set() self.request_events[request_id].set()
async def generate(self, request_dict: Dict): async def generate(self, request_dict: Dict[str, Any]):
# Preprocess the request. # Preprocess the request.
arrival_time = time.time()
prompt = request_dict.pop("prompt") prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
token_ids = self.tokenizer.encode(prompt)
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
seqs.append(seq)
arrival_time = time.time()
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
# Create an event to notify us that there is new output from the # Create an event to notify us that there is new output from the
# cacheflow server. # cacheflow server.
group_event = asyncio.Event() request_id = str(uuid.uuid4().hex[:8])
self.running_seq_groups[group_id] = seq_group request_event = asyncio.Event()
self.sequence_group_events[group_id] = group_event self.request_events[request_id] = request_event
# Add the request into the cacheflow server's waiting queue. # Add the request into the cacheflow server's waiting queue.
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)]) await self.server.add_request.remote(
request_id, prompt, sampling_params, arrival_time=arrival_time)
# The cacheflow server does not have a background loop that keeps # The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking # processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests. # the server to process the requests.
...@@ -118,32 +69,35 @@ class FastAPIServer: ...@@ -118,32 +69,35 @@ class FastAPIServer:
# Kick the server if the server is not running. # Kick the server if the server is not running.
if not self.is_server_running: if not self.is_server_running:
await self.server_step() await self.server_step()
# Wait for new output. The group_event will be set in server_step # Wait for new output. The group_event will be set in server_step
# when there is new output available for the sequence group. # when there is new output available for the sequence group.
# Added a timeout to prevent deadlock. # Added a timeout to prevent deadlock.
try: try:
await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK) await asyncio.wait_for(request_event.wait(),
timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
# Reset the event to wait for the next output. # Reset the event to wait for the next output.
group_event.clear() request_event.clear()
# Decode and return new outputs
seq_group = self.running_seq_groups[group_id] # Decode and return new outputs.
all_outputs = [] request_output = self.request_outputs[request_id]
for seq in seq_group.seqs: prompt = request_output.prompt
token_ids = seq.get_token_ids() text_outputs = [
output = self.tokenizer.decode(token_ids, skip_special_tokens=True) prompt + output.text
all_outputs.append(output) for output in request_output.outputs
]
ret = { ret = {
"text": all_outputs, "text": text_outputs,
"error": 0, "error": 0,
} }
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
# Once finished, release the resources of the sequence group. # Once finished, release the resources of the sequence group.
if seq_group.is_finished(): if request_output.done:
del self.running_seq_groups[group_id] del self.request_outputs[request_id]
del self.sequence_group_events[group_id] del self.request_events[request_id]
# Kick the server if the server is not running. This is to # Kick the server if the server is not running. This is to
# prevent that there are still requests in server's waiting # prevent that there are still requests in server's waiting
# queue to be executed. # queue to be executed.
...@@ -164,38 +118,11 @@ if __name__ == "__main__": ...@@ -164,38 +118,11 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=10002) parser.add_argument("--port", type=int, default=10002)
parser = add_server_arguments(parser) parser = add_server_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args = process_server_arguments(args)
# TODO(zhuohan): Support pipeline parallelism. server_configs = create_server_configs_from_args(args)
assert args.pipeline_parallel_size == 1, ( parallel_config = server_configs[2]
'Pipeline parallelism is not supported yet.') distributed_init_method, stage_devices = initialize_cluster(parallel_config)
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_cluster(
use_ray=True,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
server = FastAPIServer( server = FastAPIServer(
model=args.model, args.use_ray, *server_configs, distributed_init_method, stage_devices)
cache_dir=args.cache_dir,
use_np_cache=args.use_np_cache,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
server_use_ray=args.use_ray,
log_stats=args.log_stats,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
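For completeness, a hypothetical streaming client for the handler above. The "\0"-delimited JSON framing and the `text` field come from the `generate()` code; the `/generate` route path is not visible in this diff and is assumed, as is the use of the third-party `requests` library.

```python
import json

import requests  # third-party HTTP client, used here only for illustration

response = requests.post(
    "http://localhost:10002/generate",   # default port from the args above
    json={"prompt": "Hello, my name is", "n": 1},
    stream=True,
)
# The server yields "\0"-terminated JSON chunks with a "text" list.
for chunk in response.iter_lines(delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
```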
import time
from typing import List, Optional, Tuple
from cacheflow.frontend.utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter
logger = init_logger(__name__)
class SimpleFrontend:
def __init__(
self,
model_name: str,
block_size: int,
) -> None:
self.block_size = block_size
self.tokenizer = get_tokenizer(model_name)
self.seq_group_counter = Counter()
self.seq_counter = Counter()
self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
# Stop generation when we see an EOS token.
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
return sampling_params
def query(
self,
prompt: str,
sampling_params: SamplingParams,
) -> None:
token_ids = self.tokenizer.encode(prompt)
self._add_query(prompt, token_ids, sampling_params)
def _add_query(
self,
prompt: str,
token_ids: List[int],
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
) -> None:
if arrival_time is None:
arrival_time = time.time()
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
seqs.append(seq)
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
self.inputs.append((seq_group, sampling_params))
def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
inputs = self.inputs
self.inputs = []
return inputs
def print_response(
self,
seq_group: SequenceGroup,
) -> None:
for seq in seq_group.seqs:
token_ids = seq.get_token_ids()
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
output = output.strip()
logger.info(f"Seq {seq.seq_id}: {output!r}")
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.model_loader import get_model
from cacheflow.model_executor.utils import set_random_seed

__all__ = [
    "InputMetadata",
    "get_model",
    "set_random_seed",
]
...@@ -10,9 +10,9 @@ from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]


class GPTCacheFlowAttention(nn.Module):
    """GPT-style multi-head attention.
...
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig from transformers import PretrainedConfig
from cacheflow.config import ModelConfig
from cacheflow.model_executor.models import ( from cacheflow.model_executor.models import (
GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM) GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
from cacheflow.model_executor.weight_utils import initialize_dummy_weights from cacheflow.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes. # TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = { _MODEL_REGISTRY = {
"GPT2LMHeadModel": GPT2LMHeadModel, "GPT2LMHeadModel": GPT2LMHeadModel,
...@@ -19,6 +16,7 @@ _MODEL_REGISTRY = { ...@@ -19,6 +16,7 @@ _MODEL_REGISTRY = {
"OPTForCausalLM": OPTForCausalLM, "OPTForCausalLM": OPTForCausalLM,
} }
def _get_model_architecture(config: PretrainedConfig) -> nn.Module: def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
...@@ -30,51 +28,22 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module: ...@@ -30,51 +28,22 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
) )
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype: def get_model(model_config: ModelConfig) -> nn.Module:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct model_class = _get_model_architecture(model_config.hf_config)
# because config.torch_dtype can be None. torch.set_default_dtype(model_config.dtype)
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
if dtype == "default":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
torch_dtype = get_torch_dtype(dtype)
if torch_dtype != config_dtype and config_dtype != torch.float32:
# TODO(woosuk): Allow using float16 for bfloat16 models and
# vice versa. Print a warning message and continue.
raise ValueError(
f"Cannot use {torch_dtype} for {config_dtype} model.")
return torch_dtype
def get_model(
model_name: str,
dtype: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
) -> nn.Module:
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
torch.set_default_dtype(torch_dtype)
model_class = _get_model_architecture(config)
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
model = model_class(config) model = model_class(model_config.hf_config)
if use_dummy_weights: if model_config.use_dummy_weights:
model = model.cuda() model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model) initialize_dummy_weights(model)
else: else:
# Load the weights from the cached or downloaded files. # Load the weights from the cached or downloaded files.
model.load_weights(model_name, cache_dir, use_np_cache) model.load_weights(
model_config.model, model_config.download_dir,
model_config.use_np_weights)
        model = model.cuda()
    return model.eval()
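`get_model` now takes a `ModelConfig` instead of loose arguments and returns only the module, since the resolved dtype already lives on the config. A hedged sketch of calling it directly (requires a CUDA device; `ModelConfig` is as defined in `cacheflow/config.py` above):

```python
from cacheflow.config import ModelConfig
from cacheflow.model_executor import get_model

# use_dummy_weights=True keeps this self-contained: the weights are randomly
# initialized instead of downloaded from the Hugging Face Hub.
model_config = ModelConfig(model="facebook/opt-125m", download_dir=None,
                           use_np_weights=False, use_dummy_weights=True,
                           dtype="default", seed=0)
model = get_model(model_config)  # nn.Module on GPU, already in eval() mode
```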
"""Utils for model executor.""" """Utils for model executor."""
import random import random
from typing import Union
import numpy as np import numpy as np
import torch import torch
...@@ -9,28 +8,6 @@ from cacheflow.model_executor.parallel_utils.parallel_state import model_paralle ...@@ -9,28 +8,6 @@ from cacheflow.model_executor.parallel_utils.parallel_state import model_paralle
from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"float": torch.float,
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
if isinstance(dtype, str):
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
else:
torch_dtype = dtype
return torch_dtype
def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
torch_dtype = get_torch_dtype(dtype)
return torch.tensor([], dtype=torch_dtype).element_size()
def set_random_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
...@@ -40,15 +17,3 @@ def set_random_seed(seed: int) -> None:
    if model_parallel_is_initialized():
        model_parallel_cuda_manual_seed(seed)
def get_cache_block_size(block_size: int,
num_heads: int,
head_size: int,
num_layers: int,
dtype: str) -> int:
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(dtype)
return dtype_size * total
from typing import Dict, List, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from cacheflow.sequence import SequenceGroup
class CompletionOutput:
def __init__(
self,
text: str,
token_ids: List[int],
cumulative_logprobs: float,
logprobs: List[Dict[int, float]],
) -> None:
self.text = text
self.token_ids = token_ids
self.cumulative_logprobs = cumulative_logprobs
self.logprobs = logprobs
def __repr__(self) -> str:
return (f"CompletionOutput(output={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprobs={self.cumulative_logprobs}, "
f"logprobs={self.logprobs})")
class RequestOutput:
def __init__(
self,
request_id: int,
prompt: str,
prompt_token_ids: List[int],
outputs: List[CompletionOutput],
done: bool = False,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.outputs = outputs
self.done = done
@staticmethod
def from_seq_group(
seq_group: SequenceGroup,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> "RequestOutput":
outputs: List[CompletionOutput] = []
seqs = seq_group.get_seqs()
for seq in seqs:
output_token_ids = seq.data.output_token_ids
output_str = tokenizer.decode(output_token_ids,
skip_special_tokens=True)
seq_logprobs = seq.data.cumulative_logprobs
logprobs = seq.output_logprobs
if seq_group.sampling_params.logprobs == 0:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
logprobs = {}
output = CompletionOutput(output_str, output_token_ids,
seq_logprobs, logprobs)
outputs.append(output)
# Every sequence in the sequence group should have the same prompt.
prompt = seqs[0].prompt
prompt_token_ids = seqs[0].data.prompt_token_ids
return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
outputs, seq_group.is_finished())
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs}, "
f"done={self.done})")
...@@ -116,4 +116,4 @@ class SamplingParams:
                f"use_beam_search={self.use_beam_search}, "
                f"stop_token_ids={self.stop_token_ids}, "
                f"max_tokens={self.max_tokens}, "
                f"logprobs={self.logprobs})")
...@@ -115,12 +115,14 @@ class SequenceGroup:

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time

    def get_seqs(
...@@ -145,21 +147,22 @@ class SequenceGroup:
        return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")


class SequenceGroupMetadata:

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
...
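Request identity is now a string `request_id`, and the `SamplingParams` travel with the `SequenceGroup` instead of living in a separate scheduler-side dict. A hedged sketch of building a group from a tokenized prompt; the `Sequence` constructor is taken from the frontend code removed earlier in this diff, and the sequence-id assignment here is illustrative (the server draws ids from a global `Counter`).

```python
import time
from typing import List

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup


def make_seq_group(request_id: str, prompt: str, token_ids: List[int],
                   sampling_params: SamplingParams,
                   block_size: int = 16) -> SequenceGroup:
    """Bundle the n parallel sequences of one request into a SequenceGroup."""
    # Illustrative seq ids only; the real server uses a shared Counter.
    seqs = [Sequence(seq_id, prompt, token_ids, block_size=block_size)
            for seq_id in range(sampling_params.n)]
    return SequenceGroup(request_id, seqs, sampling_params,
                         arrival_time=time.time())
```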
import argparse
from typing import Tuple
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
_GiB = 1 << 30
def add_server_arguments(parser: argparse.ArgumentParser):
"""Shared arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--download-dir', type=str, default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--use-np-weights', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'))
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
return parser
def create_server_configs_from_args(
args: argparse.Namespace,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Post-process the parsed arguments.
args.swap_space = args.swap_space * _GiB
args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
# Initialize the configs.
model_config = ModelConfig(
args.model, args.download_dir, args.use_np_weights,
args.use_dummy_weights, args.dtype, args.seed)
cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
args.swap_space)
parallel_config = ParallelConfig(args.pipeline_parallel_size,
args.tensor_parallel_size, args.use_ray)
scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
args.max_num_seqs)
return model_config, cache_config, parallel_config, scheduler_config
def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
server_configs = create_server_configs_from_args(args)
parallel_config = server_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM server.
server = LLMServer(*server_configs, distributed_init_method, devices,
log_stats=not args.disable_log_stats)
return server
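A minimal sketch of how these helpers are meant to be combined in a launcher script (the argument values are illustrative; the module path `cacheflow.server.arg_utils` matches the imports used elsewhere in this change):

```python
import argparse

from cacheflow.server.arg_utils import (add_server_arguments,
                                        initialize_server_from_args)

parser = argparse.ArgumentParser(description="CacheFlow server sketch")
parser = add_server_arguments(parser)
# In a real script this would simply be parser.parse_args().
args = parser.parse_args(["--model", "facebook/opt-125m", "--max-num-seqs", "64"])

# Builds the configs, initializes the (optionally Ray-backed) cluster, and
# returns a ready-to-use LLMServer.
server = initialize_server_from_args(args)
```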
import time
from typing import Any, List, Optional
try:
import ray
except ImportError:
ray = None
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker
logger = init_logger(__name__)
class LLMServer:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[Any]],
log_stats: bool = True,
) -> None:
logger.info(
"Initializing an LLM server with config: "
f"model={model_config.model!r}, "
f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, "
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})"
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self._verify_args()
self.tokenizer = get_tokenizer(model_config.model)
self.seq_counter = Counter()
# Create the parallel GPU workers.
self.workers: List[Worker] = []
assert len(stage_devices) == 1, "Only support one stage for now."
for rank, node_resource, _ in stage_devices[0]:
worker_cls = Worker
if self.parallel_config.use_ray:
worker_cls = ray.remote(
num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-5},
)(worker_cls).remote
worker = worker_cls(
model_config,
parallel_config,
scheduler_config,
rank,
distributed_init_method,
)
self.workers.append(worker)
# Profile the memory usage and initialize the cache.
self._init_cache()
# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
def _init_cache(self) -> None:
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
get_all_outputs=True,
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
        # operations can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log.
logger.info(f'# GPU blocks: {num_gpu_blocks}, '
f'# CPU blocks: {num_cpu_blocks}')
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
def add_request(
self,
request_id: str,
prompt: str,
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
if arrival_time is None:
arrival_time = time.time()
if prompt_token_ids is None:
prompt_token_ids = self.tokenizer.encode(prompt)
# Create the sequences.
block_size = self.cache_config.block_size
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seqs.append(seq)
# FIXME(woosuk)
# Add the EOS token to the stop token list.
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
# Create the sequence group.
seq_group = SequenceGroup(request_id, seqs, sampling_params,
arrival_time)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def has_unfinished_requests(self) -> bool:
return self.scheduler.has_unfinished_seqs()
def step(self) -> List[RequestOutput]:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
# Nothing to do.
return []
# Execute the model.
output = self._run_workers(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
# Update the scheduler.
updated_seq_groups = self.scheduler.update(output)
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in updated_seq_groups:
# TODO(woosuk): Batch-decode the outputs for speedup.
request_output = RequestOutput.from_seq_group(seq_group,
self.tokenizer)
request_outputs.append(request_output)
return request_outputs
def _run_workers(
self,
method: str,
get_all_outputs: bool = False,
*args,
**kwargs,
) -> Any:
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
if self.parallel_config.use_ray:
executor = executor.remote
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.use_ray:
all_outputs = ray.get(all_outputs)
if get_all_outputs:
return all_outputs
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
assert output == other_output
return output
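Putting the pieces above together, a sketch of the add_request/step loop that drives the server. The default-constructed `SamplingParams()` is an assumption; its real constructor is defined in `cacheflow/sampling_params.py`, which is not shown in this diff.

```python
import argparse

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (add_server_arguments,
                                        initialize_server_from_args)

parser = add_server_arguments(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m"])
server = initialize_server_from_args(args)

sampling_params = SamplingParams()  # assumption: default-constructible
server.add_request("req-0", "The capital of France is", sampling_params)

# Drive the engine until every submitted request has finished.
while server.has_unfinished_requests():
    for request_output in server.step():
        print(request_output)  # RequestOutput fields are defined in cacheflow/outputs.py
```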
import random
from typing import List, Optional, Tuple
try:
import ray
except ImportError:
ray = None
from cacheflow.config import ParallelConfig
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
def initialize_cluster(
parallel_config: ParallelConfig,
address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
if not parallel_config.use_ray:
# Initialize cluster locally.
port = random.randint(10000, 20000)
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
all_stage_devices = [[(0, None, 0)]]
return distributed_init_method, all_stage_devices
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
    # For now, assume a uniform cluster in which every node has the same
    # number of GPUs.
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
        # Skip dead nodes and nodes without GPUs (CPU-only nodes may not
        # report a 'GPU' resource at all).
        if (not node['Alive']) or node['Resources'].get('GPU', 0) <= 0:
continue
if num_devices_per_node is None:
num_devices_per_node = node['Resources']['GPU']
else:
assert num_devices_per_node == node['Resources']['GPU'], (
"The number of GPUs per node is not uniform.")
for key in node['Resources']:
if key.startswith('node:'):
valid_node_resources.append(key)
# Verify the parallel config.
num_nodes = len(valid_node_resources)
if parallel_config.world_size > num_nodes * num_devices_per_node:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs.")
if parallel_config.tensor_parallel_size >= num_devices_per_node:
if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
raise ValueError(
"The number of tensor parallelism is not divisible by the "
"number of GPUs per node.")
else:
if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
raise ValueError(
"The number of GPUs per node is not divisible by the number "
"of tensor parallelism.")
# Assign GPUs to pipeline stages.
rank = 0
current_node_id = 0
current_device_id = 0
distributed_init_method = None
all_stage_devices = []
for _ in range(parallel_config.pipeline_parallel_size):
stage_devices = []
for _ in range(parallel_config.tensor_parallel_size):
node_resource = valid_node_resources[current_node_id]
stage_devices.append((rank, node_resource, current_device_id))
if distributed_init_method is None:
ip = node_resource.split("node:")[-1]
port = random.randint(10000, 20000)
distributed_init_method = f"tcp://{ip}:{port}"
rank += 1
current_device_id += 1
if current_device_id >= num_devices_per_node:
current_node_id += 1
current_device_id = 0
all_stage_devices.append(stage_devices)
return distributed_init_method, all_stage_devices
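For the single-GPU, non-Ray path above, a small sketch of what a caller receives (ParallelConfig is constructed positionally, matching its use in create_server_configs_from_args):

```python
from cacheflow.config import ParallelConfig
from cacheflow.server.ray_utils import initialize_cluster

# pipeline_parallel_size=1, tensor_parallel_size=1, use_ray=False.
parallel_config = ParallelConfig(1, 1, False)
distributed_init_method, all_stage_devices = initialize_cluster(parallel_config)

print(distributed_init_method)  # e.g. tcp://localhost:<random port between 10000 and 20000>
print(all_stage_devices)        # [[(0, None, 0)]]: one stage, rank 0 on the local GPU
```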
...@@ -3,7 +3,6 @@ from typing import Union

from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
    # LLaMA fast tokenizer has a bug related to protobuf.
    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
......
...@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple

import torch

from cacheflow import cache_ops
from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -18,27 +19,22 @@ class CacheEngine:

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_heads(parallel_config)
        self.dtype = model_config.dtype

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
...@@ -48,7 +44,7 @@ class CacheEngine:
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # Initialize the events for stream synchronization.
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        element_size = torch.tensor([], dtype=self.dtype).element_size()
...@@ -133,3 +129,23 @@ class CacheEngine:
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
@staticmethod
def get_cache_block_size(
block_size: int,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = _get_dtype_size(model_config.dtype)
return dtype_size * total
def _get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()
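To make the block-size arithmetic concrete, here is a standalone sketch that mirrors `get_cache_block_size` for an illustrative, roughly LLaMA-7B-like shape (the head, layer, and dtype values are assumptions for the example, not read from a real config):

```python
import torch

block_size = 16        # tokens per KV-cache block (the default --block-size)
num_heads = 32         # assumed attention heads per worker
head_size = 128        # assumed head dimension
num_layers = 32        # assumed number of layers
dtype = torch.float16  # 2 bytes per element

key_cache_block = block_size * num_heads * head_size  # 65,536 elements
value_cache_block = key_cache_block                   # values have the same shape
total_elements = num_layers * (key_cache_block + value_cache_block)
dtype_size = torch.tensor([], dtype=dtype).element_size()

print(total_elements * dtype_size)  # 8,388,608 bytes, i.e. 8 MiB per cache block
```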
from typing import List, Optional, Tuple, Union
try:
import ray
except ImportError:
ray = None
from cacheflow.core.scheduler import Scheduler
from cacheflow.worker.worker import Worker
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
class Controller:
def __init__(
self,
stage_id: int,
stage_devices: List[DeviceID],
world_size: int,
tensor_parallel_size: int,
pipeline_parallel_size: int,
distributed_init_method: str,
model_name: str,
dtype: str,
seed: int,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
max_num_batched_tokens: int,
max_num_sequences: int,
use_ray: bool,
) -> None:
self.stage_id = stage_id
self.stage_devices = stage_devices
self.model_name = model_name
self.use_ray = use_ray
# Which pipeline stage is this node assigned to?
self.is_first_stage = stage_id == 0
self.is_last_stage = False
self.workers: List[Worker] = []
for rank, node_resource, device_id in stage_devices:
if self.use_ray:
worker_cls = ray.remote(num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-5})(Worker).remote
else:
worker_cls = Worker
worker = worker_cls(
model_name=model_name,
dtype=dtype,
seed=seed,
distributed_init_method=distributed_init_method,
rank=rank,
world_size=world_size,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
)
self.workers.append(worker)
def get_num_available_blocks(self, block_size: int, cpu_swap_space: int,
gpu_memory_utilization: float) -> List[Tuple[int, int]]:
all_worker_results = []
for worker in self.workers:
executor = worker.get_num_available_blocks
if self.use_ray:
executor = executor.remote
result = executor(
block_size,
cpu_swap_space,
gpu_memory_utilization,
)
all_worker_results.append(result)
if self.use_ray:
all_worker_results = ray.get(all_worker_results)
return all_worker_results
def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
num_cpu_blocks: int):
all_worker_futures = []
for worker in self.workers:
executor = worker.init_cache_engine
if self.use_ray:
executor = executor.remote
future = executor(
block_size,
num_gpu_blocks,
num_cpu_blocks,
)
all_worker_futures.append(future)
if self.use_ray:
ray.get(all_worker_futures)
def set_next(
self,
next_node: Union['Controller', 'Scheduler'],
) -> None:
self.next_node = next_node
self.is_last_stage = isinstance(next_node, Scheduler)
def execute_stage(self, *args, **kwargs) -> None:
all_outputs = []
for worker in self.workers:
executor = (worker.execute_stage.remote
if self.use_ray else worker.execute_stage)
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.use_ray:
all_outputs = ray.get(all_outputs)
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
assert output == other_output
if self.is_last_stage:
self.next_node.post_step(output)
else:
# TODO: Support pipeline parallelism.
assert False