Unverified commit 0b98ba15, authored by Woosuk Kwon, committed by GitHub

Change the name to vLLM (#150)

parent e5464ee4
@@ -3,8 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig
 
-from cacheflow.logger import init_logger
-from cacheflow.utils import get_cpu_memory
+from vllm.logger import init_logger
+from vllm.utils import get_cpu_memory
 
 logger = init_logger(__name__)
@@ -87,7 +87,7 @@ class CacheConfig:
     Args:
         block_size: Size of a cache block in number of tokens.
         gpu_memory_utilization: Fraction of GPU memory to use for the
-            CacheFlow execution.
+            vLLM execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
     """
     def __init__(
...

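The CacheConfig docstring above documents the three knobs that control KV-cache sizing. A minimal sketch of constructing it follows, assuming the constructor takes exactly the documented arguments; the rest of its signature is truncated in this diff.

```python
# Hedged sketch: assumes CacheConfig's constructor mirrors the documented args;
# any additional parameters are not visible in this diff.
from vllm.config import CacheConfig

cache_config = CacheConfig(
    block_size=16,                # tokens per KV-cache block
    gpu_memory_utilization=0.90,  # fraction of GPU memory given to vLLM execution
    swap_space=4,                 # CPU swap space per GPU, in GiB
)
```
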
"""A block manager that manages token blocks.""" """A block manager that manages token blocks."""
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
from cacheflow.block import PhysicalTokenBlock from vllm.block import PhysicalTokenBlock
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Device from vllm.utils import Device
class BlockAllocator: class BlockAllocator:
......
 from typing import List
 
-from cacheflow.sequence import SequenceGroup
+from vllm.sequence import SequenceGroup
 
 
 class Policy:
...

@@ -2,13 +2,13 @@ import enum
 import time
 from typing import Dict, List, Optional, Tuple
 
-from cacheflow.config import CacheConfig, SchedulerConfig
-from cacheflow.core.block_manager import BlockSpaceManager
-from cacheflow.core.policy import PolicyFactory
-from cacheflow.logger import init_logger
-from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
-                                SequenceGroupMetadata, SequenceOutputs,
-                                SequenceStatus)
+from vllm.config import CacheConfig, SchedulerConfig
+from vllm.core.block_manager import BlockSpaceManager
+from vllm.core.policy import PolicyFactory
+from vllm.logger import init_logger
+from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
+                           SequenceGroupMetadata, SequenceOutputs,
+                           SequenceStatus)
 
 logger = init_logger(__name__)
...

@@ -3,13 +3,13 @@ import dataclasses
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
-from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
-                              SchedulerConfig)
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig)
 
 
 @dataclass
 class EngineArgs:
-    """Arguments for CacheFlow engine."""
+    """Arguments for vLLM engine."""
     model: str
     download_dir: Optional[str] = None
     use_np_weights: bool = False
@@ -33,7 +33,7 @@ class EngineArgs:
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        """Shared CLI arguments for CacheFlow engine."""
+        """Shared CLI arguments for vLLM engine."""
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
                             help='name or path of the huggingface model to use')
@@ -118,7 +118,7 @@ class EngineArgs:
 
 @dataclass
 class AsyncEngineArgs(EngineArgs):
-    """Arguments for asynchronous CacheFlow engine."""
+    """Arguments for asynchronous vLLM engine."""
     engine_use_ray: bool = False
     disable_log_requests: bool = False
...

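Because add_cli_args() registers the shared engine flags on an ordinary argparse parser, a script can reuse them directly. A minimal sketch, assuming only the method shown above; any helper that converts the parsed namespace back into an EngineArgs instance is not visible in this diff and is omitted.

```python
# Hedged sketch: wires the shared vLLM engine flags into a standalone script.
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser(description="Example script using the vLLM engine arguments.")
parser = EngineArgs.add_cli_args(parser)

# Parse a sample command line; --model defaults to 'facebook/opt-125m' as shown above.
args = parser.parse_args(["--model", "facebook/opt-125m"])
print(args.model)
```
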
@@ -2,12 +2,12 @@ import asyncio
 import time
 from typing import Dict, List, Optional
 
-from cacheflow.engine.arg_utils import AsyncEngineArgs
-from cacheflow.engine.llm_engine import LLMEngine
-from cacheflow.engine.ray_utils import initialize_cluster, ray
-from cacheflow.logger import init_logger
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.ray_utils import initialize_cluster, ray
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
@@ -104,7 +104,7 @@ class AsyncLLMEngine:
         arrival_time = time.time()
 
         # Create an event to notify us that there is new output from the
-        # cacheflow engine.
+        # vLLM engine.
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event
@@ -114,7 +114,7 @@ class AsyncLLMEngine:
                     f"sampling params: {sampling_params}, "
                     f"prompt token ids: {prompt_token_ids}.")
 
-        # Add the request into the cacheflow engine's waiting queue.
+        # Add the request into the vLLM engine's waiting queue.
         if self.engine_use_ray:
             await self.engine.add_request.remote(
                 request_id, prompt, sampling_params,
@@ -126,7 +126,7 @@ class AsyncLLMEngine:
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)
 
-        # The cacheflow engine does not have a background loop that keeps
+        # The vLLM engine does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
         # the engine to process the requests.
         while True:
...

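The comment in the hunk above explains the design: the engine has no background loop, so the coroutine that submitted a request keeps "kicking" the engine until its per-request asyncio.Event fires. A hedged sketch of that pattern is below; engine_step() and the event bookkeeping are simplified stand-ins for whatever AsyncLLMEngine actually does beyond the lines visible in this diff.

```python
# Hedged sketch of the kicking loop described above, not the actual
# AsyncLLMEngine implementation.
import asyncio

async def wait_for_output(engine_step, request_event: asyncio.Event,
                          kick_interval: float = 1.0) -> None:
    """Repeatedly kick the engine until this request's event is set."""
    while True:
        await engine_step()  # process one batch of waiting requests
        try:
            await asyncio.wait_for(request_event.wait(), timeout=kick_interval)
        except asyncio.TimeoutError:
            continue         # no new output for this request yet; kick again
        request_event.clear()
        return
```
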
 import time
 from typing import Any, List, Optional
 
-from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
-                              SchedulerConfig)
-from cacheflow.core.scheduler import Scheduler
-from cacheflow.engine.arg_utils import EngineArgs
-from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
-from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
-                                              get_tokenizer)
-from cacheflow.logger import init_logger
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
-from cacheflow.utils import Counter
-from cacheflow.worker.worker import Worker
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig)
+from vllm.core.scheduler import Scheduler
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
+from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.utils import Counter
+from vllm.worker.worker import Worker
 
 logger = init_logger(__name__)
@@ -21,7 +20,7 @@ logger = init_logger(__name__)
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
 
-    This is the main class for the CacheFlow LLM engine. It receives requests
+    This is the main class for the vLLM engine. It receives requests
     from clients and generates texts from the LLM. It includes a tokenizer, a
     language model (possibly distributed across multiple GPUs), and GPU memory
     space allocated for intermediate states (aka KV cache). This class utilizes
...

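The docstring above positions LLMEngine as the core request-processing loop. A hedged sketch of driving it directly follows: add_request() and its arguments appear in the AsyncLLMEngine hunk earlier, while the factory and stepping helpers used here are assumptions that are not visible in this diff.

```python
# Hedged sketch of driving LLMEngine by hand. from_engine_args(), step(), and
# has_unfinished_requests() are assumed helpers (not shown in this diff);
# add_request(request_id, prompt, sampling_params, ...) is called the same way
# in the AsyncLLMEngine code above.
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

engine_args = EngineArgs(model="facebook/opt-125m")
engine = LLMEngine.from_engine_args(engine_args)   # assumed factory

engine.add_request("request-0", "Hello, my name is",
                   SamplingParams(temperature=0.8))

# The engine has no background loop, so the caller repeatedly kicks it,
# mirroring the comment in the AsyncLLMEngine hunk above.
while engine.has_unfinished_requests():            # assumed helper
    for request_output in engine.step():           # assumed per-iteration API
        print(request_output)
```
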
@@ -6,7 +6,7 @@ try:
 except ImportError:
     ray = None
 
-from cacheflow.config import ParallelConfig
+from vllm.config import ParallelConfig
 
 DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
...

@@ -3,7 +3,7 @@ from typing import List, Tuple, Union
 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
-from cacheflow.logger import init_logger
+from vllm.logger import init_logger
 
 logger = init_logger(__name__)
...

@@ -6,10 +6,10 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import Response, StreamingResponse
 import uvicorn
 
-from cacheflow.engine.arg_utils import AsyncEngineArgs
-from cacheflow.engine.async_llm_engine import AsyncLLMEngine
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import random_uuid
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
...

@@ -3,11 +3,11 @@ from typing import List, Optional, Union
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
-from cacheflow.engine.arg_utils import EngineArgs
-from cacheflow.engine.llm_engine import LLMEngine
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import Counter
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.utils import Counter
 
 
 class LLM:
...

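The LLM class above is the offline entry point that wraps LLMEngine. A minimal usage sketch under stated assumptions: the constructor is assumed to take a model name, generate() is assumed to take prompts plus SamplingParams, and the top-level re-exports are assumed since the package __init__ is not part of this diff.

```python
# Hedged sketch of offline generation with the renamed package.
# Assumptions: LLM and SamplingParams are importable from the top-level vllm
# package, LLM(model=...) loads the model, and generate() returns a list of
# RequestOutput objects (whose fields are not shown in this diff).
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m")  # same default model as the CLI argument above
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(output)  # exact RequestOutput fields are not visible in this diff
```
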
@@ -13,17 +13,17 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn
 
-from cacheflow.engine.arg_utils import AsyncEngineArgs
-from cacheflow.engine.async_llm_engine import AsyncLLMEngine
-from cacheflow.engine.tokenizer_utils import get_tokenizer
-from cacheflow.entrypoints.openai.protocol import (
-    CompletionRequest, CompletionResponse, CompletionResponseChoice,
-    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
-    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
-from cacheflow.logger import init_logger
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import random_uuid
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.tokenizer_utils import get_tokenizer
+from vllm.entrypoints.openai.protocol import (
+    CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -93,11 +93,11 @@ async def create_completion(raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.
 
     NOTE: Currently we do not support the following features:
-        - echo (since the cacheflow engine does not currently support
+        - echo (since the vLLM engine does not currently support
           getting the logprobs of prompt tokens)
        - suffix (the language models we currently support do not support
          suffix)
-        - logit_bias (to be supported in cacheflow engine)
+        - logit_bias (to be supported by vLLM engine)
     """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
@@ -107,7 +107,7 @@ async def create_completion(raw_request: Request):
         return error_check_ret
 
     if request.echo:
-        # We do not support echo since the cacheflow engine does not
+        # We do not support echo since the vLLM engine does not
         # currently support getting the logprobs of prompt tokens.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "echo is not currently supported")
@@ -118,7 +118,7 @@ async def create_completion(raw_request: Request):
                                      "suffix is not currently supported")
 
     if request.logit_bias is not None:
-        # TODO: support logit_bias in cacheflow engine.
+        # TODO: support logit_bias in vLLM engine.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")
@@ -274,7 +274,7 @@ async def create_completion(raw_request: Request):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="CacheFlow OpenAI-Compatible RESTful API server."
+        description="vLLM OpenAI-Compatible RESTful API server."
     )
     parser.add_argument("--host", type=str, default="localhost", help="host name")
     parser.add_argument("--port", type=int, default=8000, help="port number")
...

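Once the server above is running (default localhost:8000 per its CLI flags), it can be queried like the OpenAI Completion API it mimics. A hedged sketch using the Python requests library; the /v1/completions route is assumed from that API and is not itself visible in this diff, and the extra top_k field comes from the CompletionRequest definition in the protocol hunk that follows.

```python
# Hedged sketch of a client request against the OpenAI-compatible server.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",  # assumed route; host/port match the defaults above
    json={
        "model": "facebook/opt-125m",   # whichever model the server was launched with
        "prompt": "San Francisco is a",
        "max_tokens": 32,
        "temperature": 0.8,
        "top_k": 40,                    # extra parameter supported by vLLM (see CompletionRequest below)
    },
)
print(response.json())
```
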
@@ -4,7 +4,7 @@ from typing import Dict, List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field
 
-from cacheflow.utils import random_uuid
+from vllm.utils import random_uuid
 
 
 class ErrorResponse(BaseModel):
@@ -34,7 +34,7 @@ class ModelCard(BaseModel):
     id: str
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "cacheflow"
+    owned_by: str = "vllm"
     root: Optional[str] = None
     parent: Optional[str] = None
     permission: List[ModelPermission] = Field(default_factory=list)
@@ -82,7 +82,7 @@ class CompletionRequest(BaseModel):
     best_of: Optional[int] = None
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
-    # Additional parameters supported by cacheflow
+    # Additional parameters supported by vLLM
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
...

@@ -22,7 +22,7 @@ class NewLineFormatter(logging.Formatter):
         return msg
 
 
-_root_logger = logging.getLogger("cacheflow")
+_root_logger = logging.getLogger("vllm")
 _default_handler = None
...

+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.utils import set_random_seed
+
+__all__ = [
+    "InputMetadata",
+    "get_model",
+    "set_random_seed",
+]
@@ -3,8 +3,8 @@ from typing import Dict, List, Tuple
 import torch
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceData
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import SequenceData
 
 
 class InputMetadata:
...