Unverified commit 655a5e48, authored by Woosuk Kwon, committed by GitHub

Introduce LLM class for offline inference (#115)

parent f746ced0
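This commit replaces the free functions add_server_arguments, create_server_configs_from_args, and initialize_server_from_args with a ServerArgs dataclass plus an LLMServer.from_server_args classmethod, and adds an LLM entrypoint for offline inference. A minimal sketch of the new CLI-driven setup, pieced together from the diff below (no options beyond those defined there):

import argparse

from cacheflow import LLMServer, ServerArgs

# Collect the shared CacheFlow server arguments on a standard argparse parser.
parser = argparse.ArgumentParser()
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()

# Build the ServerArgs dataclass from the parsed namespace, then the server from it.
server_args = ServerArgs.from_cli_args(args)
server = LLMServer.from_server_args(server_args)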
+from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments,
-    create_server_configs_from_args,
-    initialize_server_from_args,
-)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

 __all__ = [
-    "RequestOutput",
+    "LLM",
     "SamplingParams",
+    "RequestOutput",
     "LLMServer",
-    "add_server_arguments",
-    "create_server_configs_from_args",
-    "initialize_server_from_args",
+    "ServerArgs",
     "initialize_cluster",
 ]
@@ -3,6 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig

+_GiB = 1 << 30
+

 class ModelConfig:
@@ -70,7 +72,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space = swap_space
+        self.swap_space_bytes = swap_space * _GiB

         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
     else:
+        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {dtype}")
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

     # Verify the dtype.
...
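With the _GiB constant now defined in cacheflow/config.py, CacheConfig converts the swap space from GiB to bytes at construction time and exposes it as swap_space_bytes. A quick sketch of the conversion, assuming only the three constructor arguments visible in this diff (block_size, gpu_memory_utilization, swap_space):

from cacheflow.config import CacheConfig

# swap_space is specified in GiB; CacheConfig stores it as swap_space_bytes.
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.95,
                           swap_space=4)
assert cache_config.swap_space_bytes == 4 * (1 << 30)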
@@ -12,8 +12,7 @@ import uvicorn
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments, create_server_configs_from_args)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
@@ -116,10 +115,10 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=10002)
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
-    server_configs = create_server_configs_from_args(args)
+    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
...
+from typing import List, Optional
+
+from tqdm import tqdm
+
+from cacheflow.outputs import RequestOutput
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.llm_server import LLMServer
+from cacheflow.utils import Counter
+
+
+class LLM:
+
+    def __init__(
+        self,
+        model: str,
+        tensor_parallel_size: int = 1,
+        dtype: str = "default",
+        seed: int = 0,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        server_args = ServerArgs(
+            model=model,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=dtype,
+            seed=seed,
+            **kwargs,
+        )
+        self.llm_server = LLMServer.from_server_args(server_args)
+        self.request_counter = Counter()
+
+    def generate(
+        self,
+        prompts: List[str],
+        sampling_params: Optional[SamplingParams] = None,
+        use_tqdm: bool = True,
+    ) -> List[RequestOutput]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        # Initialize tqdm.
+        if use_tqdm:
+            pbar = tqdm(total=len(prompts), desc="Processed prompts")
+        # Add requests to the server.
+        for prompt in prompts:
+            request_id = str(next(self.request_counter))
+            self.llm_server.add_request(request_id, prompt, sampling_params)
+        # Run the server.
+        outputs: List[RequestOutput] = []
+        while self.llm_server.has_unfinished_requests():
+            step_outputs = self.llm_server.step()
+            for output in step_outputs:
+                if output.done:
+                    outputs.append(output)
+                    if use_tqdm:
+                        pbar.update(1)
+        if use_tqdm:
+            pbar.close()
+        return outputs
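Extra keyword arguments to LLM are forwarded to ServerArgs unchanged, and disable_log_stats defaults to True for offline use. A short sketch of passing ServerArgs fields through the constructor; the numeric values are illustrative, not defaults or recommendations:

from cacheflow import LLM, SamplingParams

# gpu_memory_utilization and max_num_seqs are ServerArgs fields forwarded
# through **kwargs; the values here are only examples.
llm = LLM(model="facebook/opt-125m",
          gpu_memory_utilization=0.90,
          max_num_seqs=64)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95))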
@@ -35,7 +35,7 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool = False,
+        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -43,8 +43,8 @@ class RequestOutput:
         self.outputs = outputs
         self.done = done

-    @staticmethod
-    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
@@ -70,8 +70,8 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
-                             outputs, seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
+                   seq_group.is_finished())

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
...
 import argparse
-from typing import Tuple
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Tuple

 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
-from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-
-_GiB = 1 << 30


-def add_server_arguments(parser: argparse.ArgumentParser):
-    """Shared arguments for CacheFlow servers."""
+@dataclass
+class ServerArgs:
+    model: str
+    download_dir: Optional[str] = None
+    use_np_weights: bool = False
+    use_dummy_weights: bool = False
+    dtype: str = "default"
+    seed: int = 0
+    use_ray: bool = False
+    pipeline_parallel_size: int = 1
+    tensor_parallel_size: int = 1
+    block_size: int = 16
+    swap_space: int = 4  # GiB
+    gpu_memory_utilization: float = 0.95
+    max_num_batched_tokens: int = 2560
+    max_num_seqs: int = 256
+    disable_log_stats: bool = False
+
+    def __post_init__(self):
+        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        return _add_server_arguments(parser)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(
+            self.model, self.download_dir, self.use_np_weights,
+            self.use_dummy_weights, self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+def _add_server_arguments(
+    parser: argparse.ArgumentParser,
+) -> argparse.ArgumentParser:
+    """Shared CLI arguments for CacheFlow servers."""
     # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--download-dir', type=str, default=None,
+    parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                        help='name or path of the huggingface model to use')
+    parser.add_argument('--download-dir', type=str,
+                        default=ServerArgs.download_dir,
                         help='directory to download and load the weights, '
                              'default to the default cache dir of huggingface')
     parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster loading')
-    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
+                        help='save a numpy copy of model weights for faster '
+                             'loading. This can increase the disk usage by up '
+                             'to 2x.')
+    parser.add_argument('--use-dummy-weights', action='store_true',
+                        help='use dummy values for model weights')
     # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                        choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
                               'for FP32 and FP16 models, and BF16 precision '
                               'for BF16 models.'))
     # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--use-ray', action='store_true',
+                        help='use Ray for distributed serving, will be '
+                             'automatically set when using more than 1 GPU')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                        default=ServerArgs.pipeline_parallel_size,
+                        help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                        default=ServerArgs.tensor_parallel_size,
+                        help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
+    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
+                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                        help='token block size')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=0, help='random seed')
-    parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
+    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                        help='random seed')
+    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
+                        help='CPU swap space size (GiB) per GPU')
+    parser.add_argument('--gpu-memory-utilization', type=float,
+                        default=ServerArgs.gpu_memory_utilization,
+                        help='the percentage of GPU memory to be used for the '
+                             'model executor')
+    parser.add_argument('--max-num-batched-tokens', type=int,
+                        default=ServerArgs.max_num_batched_tokens,
+                        help='maximum number of batched tokens per iteration')
+    parser.add_argument('--max-num-seqs', type=int,
+                        default=ServerArgs.max_num_seqs,
+                        help='maximum number of sequences per iteration')
+    parser.add_argument('--disable-log-stats', action='store_true',
+                        help='disable logging statistics')
     return parser
-
-
-def create_server_configs_from_args(
-    args: argparse.Namespace,
-) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-    # Post-process the parsed arguments.
-    args.swap_space = args.swap_space * _GiB
-    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
-
-    # Initialize the configs.
-    model_config = ModelConfig(
-        args.model, args.download_dir, args.use_np_weights,
-        args.use_dummy_weights, args.dtype, args.seed)
-    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
-                               args.swap_space)
-    parallel_config = ParallelConfig(args.pipeline_parallel_size,
-                                     args.tensor_parallel_size, args.use_ray)
-    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
-                                       args.max_num_seqs)
-    return model_config, cache_config, parallel_config, scheduler_config
-
-
-def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
-    server_configs = create_server_configs_from_args(args)
-    parallel_config = server_configs[2]
-
-    # Initialize the cluster.
-    distributed_init_method, devices = initialize_cluster(parallel_config)
-
-    # Create the LLM server.
-    server = LLMServer(*server_configs, distributed_init_method, devices,
-                       log_stats=not args.disable_log_stats)
-    return server
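create_server_configs returns a (ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig) tuple, which is why the call sites in this commit index server_configs[2] to get the parallel config. A small sketch using only the defaults defined in the dataclass above:

from cacheflow.server.arg_utils import ServerArgs

server_args = ServerArgs(model="facebook/opt-125m")
(model_config, cache_config,
 parallel_config, scheduler_config) = server_args.create_server_configs()
# parallel_config is the third element, matching server_configs[2] above.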
@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ class LLMServer:
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ class LLMServer:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )

         # Since we use a shared centralized controller, we take the minimum
@@ -107,6 +109,18 @@ class LLMServer:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
...
+from cacheflow import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m")
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 import argparse
 import uuid

-from cacheflow import (add_server_arguments, initialize_server_from_args,
-                       SamplingParams)
+from cacheflow import ServerArgs, LLMServer, SamplingParams


 def main(args: argparse.Namespace):
-    # Initialize the server.
-    server = initialize_server_from_args(args)
+    # Parse the CLI argument and initialize the server.
+    server_args = ServerArgs.from_cli_args(args)
+    server = LLMServer.from_server_args(server_args)

     # Test the following prompts.
     test_prompts = [
@@ -39,6 +39,6 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)