Commit 99b471c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
...@@ -32,6 +32,9 @@ class LLM: ...@@ -32,6 +32,9 @@ class LLM:
tokenizer: The name or path of a HuggingFace Transformers tokenizer. tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer. if available, and "slow" will always use the slow tokenizer.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer. Expect valid prompt_token_ids and None for prompt
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer. downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed tensor_parallel_size: The number of GPUs to use for distributed
...@@ -42,10 +45,11 @@ class LLM: ...@@ -42,10 +45,11 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead. use `float16` instead.
quantization: The method used to quantize the model weights. Currently, quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq" and "squeezellm". If None, we first check we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
the `quantization_config` attribute in the model config file. If If None, we first check the `quantization_config` attribute in the
that is None, we assume the model weights are not quantized and use model config file. If that is None, we assume the model weights are
`dtype` to determine the data type of the weights. not quantized and use `dtype` to determine the data type of
the weights.
revision: The specific model version to use. It can be a branch name, revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a tokenizer_revision: The specific tokenizer version to use. It can be a
...@@ -75,6 +79,7 @@ class LLM: ...@@ -75,6 +79,7 @@ class LLM:
model: str, model: str,
tokenizer: Optional[str] = None, tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto", tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: str = "auto", dtype: str = "auto",
...@@ -86,7 +91,7 @@ class LLM: ...@@ -86,7 +91,7 @@ class LLM:
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: int = 8192, max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = True, disable_custom_all_reduce: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
...@@ -95,6 +100,7 @@ class LLM: ...@@ -95,6 +100,7 @@ class LLM:
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
...@@ -126,7 +132,8 @@ class LLM: ...@@ -126,7 +132,8 @@ class LLM:
def generate( def generate(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -141,7 +148,10 @@ class LLM: ...@@ -141,7 +148,10 @@ class LLM:
Args: Args:
prompts: A list of prompts to generate completions for. prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
...@@ -155,6 +165,10 @@ class LLM: ...@@ -155,6 +165,10 @@ class LLM:
if prompts is None and prompt_token_ids is None: if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be " raise ValueError("Either prompts or prompt_token_ids must be "
"provided.") "provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str): if isinstance(prompts, str):
# Convert a single prompt to a list. # Convert a single prompt to a list.
prompts = [prompts] prompts = [prompts]
...@@ -162,23 +176,33 @@ class LLM: ...@@ -162,23 +176,33 @@ class LLM:
and len(prompts) != len(prompt_token_ids)): and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids " raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.") "must be the same.")
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
elif isinstance(sampling_params,
list) and len(sampling_params) != num_requests:
raise ValueError("The lengths of prompts and sampling_params "
"must be the same.")
if multi_modal_data: if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16) multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids)
for i in range(num_requests): for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[ token_ids = None if prompt_token_ids is None else prompt_token_ids[
i] i]
self._add_request( self._add_request(
prompt, prompt,
sampling_params, sampling_params[i]
if isinstance(sampling_params, list) else sampling_params,
token_ids, token_ids,
lora_request=lora_request, lora_request=lora_request,
# Get ith image while maintaining the batch dim. # Get ith image while maintaining the batch dim.
...@@ -227,4 +251,4 @@ class LLM: ...@@ -227,4 +251,4 @@ class LLM:
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id)) outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs return outputs
\ No newline at end of file
...@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse) CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
...@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext ...@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion = None openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator, return StreamingResponse(content=generator,
media_type="text/event-stream") media_type="text/event-stream")
else: else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -127,7 +129,8 @@ if __name__ == "__main__": ...@@ -127,7 +129,8 @@ if __name__ == "__main__":
@app.middleware("http") @app.middleware("http")
async def authentication(request: Request, call_next): async def authentication(request: Request, call_next):
if not request.url.path.startswith("/v1"): root_path = "" if args.root_path is None else args.root_path
if not request.url.path.startswith(f"{root_path}/v1"):
return await call_next(request) return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token: if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"}, return JSONResponse(content={"error": "Unauthorized"},
...@@ -149,18 +152,18 @@ if __name__ == "__main__": ...@@ -149,18 +152,18 @@ if __name__ == "__main__":
logger.info(f"args: {args}") logger.info(f"args: {args}")
if args.served_model_name is not None: if args.served_model_name is not None:
served_model = args.served_model_name served_model_names = args.served_model_name
else: else:
served_model = args.model served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args( engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER) engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
openai_serving_chat = OpenAIServingChat(engine, served_model, openai_serving_chat = OpenAIServingChat(engine, served_model_names,
args.response_role, args.response_role,
args.lora_modules, args.lora_modules,
args.chat_template) args.chat_template)
openai_serving_completion = OpenAIServingCompletion( openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules) engine, served_model_names, args.lora_modules)
app.root_path = args.root_path app.root_path = args.root_path
uvicorn.run(app, uvicorn.run(app,
......
...@@ -54,11 +54,15 @@ def make_arg_parser(): ...@@ -54,11 +54,15 @@ def make_arg_parser():
help="If provided, the server will require this key " help="If provided, the server will require this key "
"to be presented in the header.") "to be presented in the header.")
parser.add_argument("--served-model-name", parser.add_argument("--served-model-name",
nargs="+",
type=str, type=str,
default=None, default=None,
help="The model name used in the API. If not " help="The model name(s) used in the API. If multiple "
"specified, the model name will be the same as " "names are provided, the server will respond to any "
"the huggingface name.") "of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument.")
parser.add_argument( parser.add_argument(
"--lora-modules", "--lora-modules",
type=str, type=str,
......
...@@ -5,6 +5,7 @@ from typing import Dict, List, Literal, Optional, Union ...@@ -5,6 +5,7 @@ from typing import Dict, List, Literal, Optional, Union
import torch import torch
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -30,7 +31,7 @@ class ModelPermission(BaseModel): ...@@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
allow_fine_tuning: bool = False allow_fine_tuning: bool = False
organization: str = "*" organization: str = "*"
group: Optional[str] = None group: Optional[str] = None
is_blocking: str = False is_blocking: bool = False
class ModelCard(BaseModel): class ModelCard(BaseModel):
...@@ -56,7 +57,7 @@ class UsageInfo(BaseModel): ...@@ -56,7 +57,7 @@ class UsageInfo(BaseModel):
class ResponseFormat(BaseModel): class ResponseFormat(BaseModel):
# type must be "json_object" or "text" # type must be "json_object" or "text"
type: str = Literal["text", "json_object"] type: Literal["text", "json_object"]
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
...@@ -133,6 +134,12 @@ class ChatCompletionRequest(BaseModel): ...@@ -133,6 +134,12 @@ class ChatCompletionRequest(BaseModel):
description=( description=(
"If specified, the output will follow the context free grammar."), "If specified, the output will follow the context free grammar."),
) )
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
...@@ -146,6 +153,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -146,6 +153,7 @@ class ChatCompletionRequest(BaseModel):
def logit_bias_logits_processor( def logit_bias_logits_processor(
token_ids: List[int], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor: logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items(): for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec # Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias)) bias = min(100, max(-100, bias))
...@@ -207,7 +215,7 @@ class CompletionRequest(BaseModel): ...@@ -207,7 +215,7 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None logprobs: Optional[int] = None
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = 16
n: Optional[int] = 1 n: int = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
...@@ -229,6 +237,7 @@ class CompletionRequest(BaseModel): ...@@ -229,6 +237,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0 min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params # doc: end-completion-sampling-params
# doc: begin-completion-extra-params # doc: begin-completion-extra-params
...@@ -264,6 +273,12 @@ class CompletionRequest(BaseModel): ...@@ -264,6 +273,12 @@ class CompletionRequest(BaseModel):
description=( description=(
"If specified, the output will follow the context free grammar."), "If specified, the output will follow the context free grammar."),
) )
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
# doc: end-completion-extra-params # doc: end-completion-extra-params
...@@ -276,6 +291,7 @@ class CompletionRequest(BaseModel): ...@@ -276,6 +291,7 @@ class CompletionRequest(BaseModel):
def logit_bias_logits_processor( def logit_bias_logits_processor(
token_ids: List[int], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor: logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items(): for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec # Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias)) bias = min(100, max(-100, bias))
...@@ -309,6 +325,7 @@ class CompletionRequest(BaseModel): ...@@ -309,6 +325,7 @@ class CompletionRequest(BaseModel):
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
) )
@model_validator(mode="before") @model_validator(mode="before")
......
...@@ -24,12 +24,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -24,12 +24,12 @@ class OpenAIServingChat(OpenAIServing):
def __init__(self, def __init__(self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
served_model: str, served_model_names: List[str],
response_role: str, response_role: str,
lora_modules: Optional[List[LoRA]] = None, lora_modules: Optional[List[LoRA]] = None,
chat_template=None): chat_template=None):
super().__init__(engine=engine, super().__init__(engine=engine,
served_model=served_model, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules)
self.response_role = response_role self.response_role = response_role
self._load_chat_template(chat_template) self._load_chat_template(chat_template)
...@@ -63,13 +63,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -63,13 +63,18 @@ class OpenAIServingChat(OpenAIServing):
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
try: try:
token_ids = self._validate_prompt_and_tokenize(request, # Tokenize/detokenize depending on prompt format (string/token list)
prompt=prompt) prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = ( guided_decode_logits_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, await self.engine.get_tokenizer())) guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logits_processor: if guided_decode_logits_processor:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
...@@ -78,8 +83,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -78,8 +83,8 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt, sampling_params, result_generator = self.engine.generate(prompt_text, sampling_params,
request_id, token_ids, request_id, prompt_ids,
lora_request) lora_request)
# Streaming response # Streaming response
if request.stream: if request.stream:
...@@ -104,18 +109,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -104,18 +109,18 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput], request_id: str result_generator: AsyncIterator[RequestOutput], request_id: str
) -> Union[ErrorResponse, AsyncGenerator[str, None]]: ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
model_name = request.model model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type = "chat.completion.chunk" chunk_object_type = "chat.completion.chunk"
first_iteration = True first_iteration = True
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n finish_reason_sent = [False] * request.n
try: try:
async for res in result_generator: async for res in result_generator:
res: RequestOutput
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
# response (by the try...catch). # response (by the try...catch).
...@@ -246,7 +251,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -246,7 +251,7 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = request.model model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
final_res: RequestOutput = None final_res: RequestOutput = None
...@@ -314,23 +319,30 @@ class OpenAIServingChat(OpenAIServing): ...@@ -314,23 +319,30 @@ class OpenAIServingChat(OpenAIServing):
return response return response
def _load_chat_template(self, chat_template): def _load_chat_template(self, chat_template):
tokenizer = self.tokenizer
if chat_template is not None: if chat_template is not None:
try: try:
with open(chat_template, "r") as f: with open(chat_template, "r") as f:
self.tokenizer.chat_template = f.read() tokenizer.chat_template = f.read()
except OSError: except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to # If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly # ensure we decode so our escape are interpreted correctly
self.tokenizer.chat_template = codecs.decode( tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape") chat_template, "unicode_escape")
logger.info( logger.info(
f"Using supplied chat template:\n{self.tokenizer.chat_template}" f"Using supplied chat template:\n{tokenizer.chat_template}")
) elif tokenizer.chat_template is not None:
elif self.tokenizer.chat_template is not None:
logger.info( logger.info(
f"Using default chat template:\n{self.tokenizer.chat_template}" f"Using default chat template:\n{tokenizer.chat_template}")
)
else: else:
logger.warning( logger.warning(
"No chat template provided. Chat API will not work.") "No chat template provided. Chat API will not work.")
import asyncio
import time import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple) Optional, Tuple)
...@@ -17,7 +16,7 @@ from vllm.logger import init_logger ...@@ -17,7 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.utils import random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -50,49 +49,14 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: ...@@ -50,49 +49,14 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
return prompt_is_tokens, prompts return prompt_is_tokens, prompts
def merge_async_iterators(*iterators):
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue = asyncio.Queue()
finished = [False] * len(iterators)
async def producer(i, iterator):
try:
async for item in iterator:
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
if isinstance(item, Exception):
raise item
yield item
await asyncio.gather(*_tasks)
return consumer()
class OpenAIServingCompletion(OpenAIServing): class OpenAIServingCompletion(OpenAIServing):
def __init__(self, def __init__(self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
served_model: str, served_model_names: List[str],
lora_modules: Optional[List[LoRA]] = None): lora_modules: Optional[List[LoRA]] = None):
super().__init__(engine=engine, super().__init__(engine=engine,
served_model=served_model, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules)
async def create_completion(self, request: CompletionRequest, async def create_completion(self, request: CompletionRequest,
...@@ -115,7 +79,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -115,7 +79,7 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"suffix is not currently supported") "suffix is not currently supported")
model_name = request.model model_name = self.served_model_names[0]
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time()) created_time = int(time.time())
...@@ -124,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -124,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
try: try:
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = ( guided_decode_logit_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, await self.engine.get_tokenizer())) guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None: if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
...@@ -136,17 +104,24 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -136,17 +104,24 @@ class OpenAIServingCompletion(OpenAIServing):
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if prompt_is_tokens: if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize( prompt_formats = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt) request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else: else:
input_ids = self._validate_prompt_and_tokenize( prompt_formats = self._validate_prompt_and_tokenize(
request, prompt=prompt) request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
generators.append( generators.append(
self.engine.generate(prompt, self.engine.generate(prompt_text,
sampling_params, sampling_params,
f"{request_id}-{i}", f"{request_id}-{i}",
prompt_token_ids=input_ids, prompt_token_ids=prompt_ids,
lora_request=lora_request)) lora_request=lora_request))
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
...@@ -210,6 +185,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -210,6 +185,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name: str, model_name: str,
num_prompts: int, num_prompts: int,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts has_echoed = [False] * request.n * num_prompts
...@@ -227,6 +203,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -227,6 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO(simon): optimize the performance by avoiding full # TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending. # text O(n^2) sending.
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
# only return the prompt # only return the prompt
delta_text = res.prompt delta_text = res.prompt
...@@ -304,7 +281,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -304,7 +281,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int, created_time: int,
model_name: str, model_name: str,
) -> CompletionResponse: ) -> CompletionResponse:
choices = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
num_generated_tokens = 0 num_generated_tokens = 0
for final_res in final_res_batch: for final_res in final_res_batch:
...@@ -314,13 +291,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -314,13 +291,15 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = final_res.prompt prompt_text = final_res.prompt
for output in final_res.outputs: for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids token_ids = prompt_token_ids
top_logprobs = prompt_logprobs top_logprobs = prompt_logprobs
output_text = prompt_text output_text = prompt_text
elif request.echo and request.max_tokens > 0: elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids token_ids = prompt_token_ids + output.token_ids
top_logprobs = prompt_logprobs + output.logprobs top_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs else None)
output_text = prompt_text + output.text output_text = prompt_text + output.text
else: else:
token_ids = output.token_ids token_ids = output.token_ids
...@@ -328,6 +307,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -328,6 +307,9 @@ class OpenAIServingCompletion(OpenAIServing):
output_text = output.text output_text = output.text
if request.logprobs is not None: if request.logprobs is not None:
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs( logprobs = self._create_logprobs(
token_ids=token_ids, token_ids=token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
......
...@@ -2,7 +2,11 @@ import asyncio ...@@ -2,7 +2,11 @@ import asyncio
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Tuple, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...@@ -27,10 +31,10 @@ class OpenAIServing: ...@@ -27,10 +31,10 @@ class OpenAIServing:
def __init__(self, def __init__(self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
served_model: str, served_model_names: List[str],
lora_modules=Optional[List[LoRA]]): lora_modules=Optional[List[LoRA]]):
self.engine = engine self.engine = engine
self.served_model = served_model self.served_model_names = served_model_names
if lora_modules is None: if lora_modules is None:
self.lora_requests = [] self.lora_requests = []
else: else:
...@@ -43,7 +47,8 @@ class OpenAIServing: ...@@ -43,7 +47,8 @@ class OpenAIServing:
] ]
self.max_model_len = 0 self.max_model_len = 0
self.tokenizer = None # Lazy initialized
self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
try: try:
event_loop = asyncio.get_running_loop() event_loop = asyncio.get_running_loop()
...@@ -66,18 +71,21 @@ class OpenAIServing: ...@@ -66,18 +71,21 @@ class OpenAIServing:
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
engine_model_config.tokenizer, engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode, tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code) tokenizer_revision=engine_model_config.tokenizer_revision,
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ model_cards = [
ModelCard(id=self.served_model, ModelCard(id=served_model_name,
root=self.served_model, root=self.served_model_names[0],
permission=[ModelPermission()]) permission=[ModelPermission()])
for served_model_name in self.served_model_names
] ]
lora_cards = [ lora_cards = [
ModelCard(id=lora.lora_name, ModelCard(id=lora.lora_name,
root=self.served_model, root=self.served_model_names[0],
permission=[ModelPermission()]) permission=[ModelPermission()])
for lora in self.lora_requests for lora in self.lora_requests
] ]
...@@ -87,7 +95,7 @@ class OpenAIServing: ...@@ -87,7 +95,7 @@ class OpenAIServing:
def _create_logprobs( def _create_logprobs(
self, self,
token_ids: List[int], token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None, top_logprobs: List[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None, num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0, initial_text_offset: int = 0,
) -> LogProbs: ) -> LogProbs:
...@@ -96,27 +104,36 @@ class OpenAIServing: ...@@ -96,27 +104,36 @@ class OpenAIServing:
last_token_len = 0 last_token_len = 0
if num_output_top_logprobs: if num_output_top_logprobs:
logprobs.top_logprobs = [] logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None: if step_top_logprobs is None:
token_logprob = step_top_logprobs[token_id].logprob token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append(None)
else: else:
token_logprob = None token_logprob = step_top_logprobs[token_id].logprob
token = step_top_logprobs[token_id].decoded_token token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token) logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob) logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p.decoded_token: max(p.logprob, -9999.0)
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
if len(logprobs.text_offset) == 0: if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset) logprobs.text_offset.append(initial_text_offset)
else: else:
logprobs.text_offset.append(logprobs.text_offset[-1] + logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len) last_token_len)
last_token_len = len(token) last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append({
p.decoded_token: p.logprob
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs return logprobs
def create_error_response( def create_error_response(
...@@ -142,18 +159,18 @@ class OpenAIServing: ...@@ -142,18 +159,18 @@ class OpenAIServing:
return json_str return json_str
async def _check_model(self, request) -> Optional[ErrorResponse]: async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model == self.served_model: if request.model in self.served_model_names:
return return None
if request.model in [lora.lora_name for lora in self.lora_requests]: if request.model in [lora.lora_name for lora in self.lora_requests]:
return return None
return self.create_error_response( return self.create_error_response(
message=f"The model `{request.model}` does not exist.", message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError", err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(self, request) -> Optional[LoRARequest]: def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
if request.model == self.served_model: if request.model in self.served_model_names:
return return None
for lora in self.lora_requests: for lora in self.lora_requests:
if request.model == lora.lora_name: if request.model == lora.lora_name:
return lora return lora
...@@ -161,21 +178,40 @@ class OpenAIServing: ...@@ -161,21 +178,40 @@ class OpenAIServing:
raise ValueError("The model `{request.model}` does not exist.") raise ValueError("The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize( def _validate_prompt_and_tokenize(
self, self,
request: Union[ChatCompletionRequest, CompletionRequest], request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None, prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]: prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids): if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.") raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids): if (prompt and prompt_ids):
raise ValueError( raise ValueError(
"Only one of prompt or prompt_ids should be provided.") "Only one of prompt or prompt_ids should be provided.")
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer( if prompt_ids is None:
prompt).input_ids tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
}
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids
input_text = prompt if prompt is not None else self.tokenizer.decode(
prompt_ids)
token_num = len(input_ids) token_num = len(input_ids)
if request.max_tokens is None: if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.", )
request.max_tokens = self.max_model_len - token_num request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len: if token_num + request.max_tokens > self.max_model_len:
...@@ -187,4 +223,4 @@ class OpenAIServing: ...@@ -187,4 +223,4 @@ class OpenAIServing:
f"{request.max_tokens} in the completion). " f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", ) f"Please reduce the length of the messages or completion.", )
else: else:
return input_ids return input_ids, input_text
import os
from typing import Dict, List, Set, Tuple
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
assert self.lora_config is None, "cpu backend doesn't support LoRA"
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
self.scheduler_config = _verify_and_get_scheduler_config(
self.scheduler_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
def _init_worker(self):
from vllm.worker.cpu_worker import CPUWorker
assert self.parallel_config.world_size == 1, (
"CPUExecutor only supports single CPU socket currently.")
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = CPUWorker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: `cpu block` for CPU backend is located on CPU memory but is
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
logger.info(f"# CPU blocks: {num_gpu_blocks}")
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)
return output
async def check_health_async(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
config.dtype = torch.bfloat16
if not config.enforce_eager:
logger.warning(
"CUDA graph is not supported on CPU, fallback to the eager "
"mode.")
config.enforce_eager = True
return config
def _verify_and_get_scheduler_config(
config: SchedulerConfig) -> SchedulerConfig:
if config.chunked_prefill_enabled:
logger.warning("Chunked prefill is not supported on CPU, disable it.")
config.chunked_prefill_enabled = False
return config
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30
if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False
kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")
kv_cache_space = int(kv_cache_space_str)
if kv_cache_space >= 0:
if kv_cache_space == 0:
config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"for CPU backend is not set, using 4 by default.")
else:
config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore
else:
raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
return config
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
...@@ -15,7 +16,6 @@ class ExecutorBase(ABC): ...@@ -15,7 +16,6 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices. that can execute the model on multiple devices.
""" """
@abstractmethod
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
...@@ -23,9 +23,48 @@ class ExecutorBase(ABC): ...@@ -23,9 +23,48 @@ class ExecutorBase(ABC):
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None: ) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.speculative_config = speculative_config
self._init_executor()
@abstractmethod
def _init_executor(self) -> None:
pass
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
Normally, this should simply delegate to the underlying Worker. Some
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -33,8 +72,9 @@ class ExecutorBase(ABC): ...@@ -33,8 +72,9 @@ class ExecutorBase(ABC):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
"""Executes one model step on the given sequences.""" num_lookahead_slots: int) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -46,7 +86,7 @@ class ExecutorBase(ABC): ...@@ -46,7 +86,7 @@ class ExecutorBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -69,8 +109,7 @@ class ExecutorAsyncBase(ExecutorBase): ...@@ -69,8 +109,7 @@ class ExecutorAsyncBase(ExecutorBase):
"""Executes one model step on the given sequences.""" """Executes one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an """Checks if the executor is healthy. If not, it should raise an
exception.""" exception."""
raise NotImplementedError self.check_health()
from typing import Dict, List, Optional from typing import Dict, List, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
...@@ -15,31 +12,18 @@ logger = init_logger(__name__) ...@@ -15,31 +12,18 @@ logger = init_logger(__name__)
class GPUExecutor(ExecutorBase): class GPUExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, """Initialize the worker and load the model.
model_config: ModelConfig,
cache_config: CacheConfig, If speculative decoding is enabled, we instead create the speculative
parallel_config: ParallelConfig, worker.
scheduler_config: SchedulerConfig, """
device_config: DeviceConfig, if self.speculative_config is None:
lora_config: Optional[LoRAConfig], self._init_non_spec_worker()
vision_language_config: Optional[VisionLanguageConfig], else:
) -> None: self._init_spec_worker()
self.model_config = model_config
self.cache_config = cache_config def _init_non_spec_worker(self):
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
# Instantiate the worker and load the model to GPU.
self._init_worker()
# Profile the memory usage and initialize the cache.
self._init_cache()
def _init_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers # Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker # before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -50,72 +34,107 @@ class GPUExecutor(ExecutorBase): ...@@ -50,72 +34,107 @@ class GPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
self.driver_worker = Worker( self.driver_worker = Worker(
self.model_config, model_config=self.model_config,
self.parallel_config, parallel_config=self.parallel_config,
self.scheduler_config, scheduler_config=self.scheduler_config,
self.device_config, device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True, is_driver_worker=True,
) )
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def _init_cache(self) -> None: def _init_spec_worker(self):
"""Profiles the memory usage and initializes the KV cache. """Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker import Worker
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
target_worker = Worker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=True,
)
draft_worker = MultiStepWorker(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
# TODO allow draft-model specific load config.
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=True,
)
spec_decode_worker = SpecDecodeWorker.from_workers(
proposer_worker=draft_worker, scorer_worker=target_worker)
The engine first profiles the existing memory usage. assert self.parallel_config.world_size == 1, (
Then, it allocates the remaining memory for KV blocks. "GPUExecutor only supports single GPU.")
self.driver_worker = spec_decode_worker
.. tip:: # Load model handled in spec decode worker.
You may limit the usage of GPU memory self.driver_worker.init_device()
by adjusting the `gpu_memory_utilization` parameter.
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
""" """
# Get the maximum number of blocks that can be allocated on GPU and CPU. return self.driver_worker.determine_num_available_blocks()
num_gpu_blocks, num_cpu_blocks = (
self.driver_worker.profile_num_available_blocks(
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.
gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
))
if self.cache_config.forced_num_gpu_blocks is not None:
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
f"{forced_num_gpu_blocks=}")
num_gpu_blocks = forced_num_gpu_blocks
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info(f"# GPU blocks: {num_gpu_blocks}, " logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}") f"# CPU blocks: {num_cpu_blocks}")
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks def execute_model(
self.cache_config.num_cpu_blocks = num_cpu_blocks self,
seq_group_metadata_list: List[SequenceGroupMetadata],
# Initialize the cache. blocks_to_swap_in: Dict[int, int],
self.driver_worker.init_cache_engine(cache_config=self.cache_config) blocks_to_swap_out: Dict[int, int],
# Warm up the model. This includes capturing the model into CUDA graph blocks_to_copy: Dict[int, List[int]],
# if enforce_eager is False. num_lookahead_slots: int,
self.driver_worker.warm_up_model() ) -> List[SamplerOutput]:
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
output = self.driver_worker.execute_model( output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
num_lookahead_slots=num_lookahead_slots,
) )
return output return output
...@@ -127,7 +146,7 @@ class GPUExecutor(ExecutorBase): ...@@ -127,7 +146,7 @@ class GPUExecutor(ExecutorBase):
assert lora_id > 0, "lora_id must be greater than 0." assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id) return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def check_health(self) -> None: def check_health(self) -> None:
...@@ -151,8 +170,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): ...@@ -151,8 +170,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy) blocks_to_copy=blocks_to_copy)
return output return output
async def check_health_async(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return
from typing import Dict, List, Optional from typing import Dict, List, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_async
logger = init_logger(__name__) logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase): class NeuronExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, assert (self.lora_config is
model_config: ModelConfig, None), "LoRA is not supported for Neuron backend."
cache_config: CacheConfig, assert (not self.speculative_config
parallel_config: ParallelConfig, ), "Speculative decoding not yet supported for Neuron backend."
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
assert lora_config is None, "LoRA is not supported for Neuron backend."
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
self.cache_config.num_cpu_blocks = 0
# Instantiate the worker and load the model to the device. # Instantiate the worker and load the model to the device.
self._init_worker() self._init_worker()
...@@ -46,36 +28,68 @@ class NeuronExecutor(ExecutorBase): ...@@ -46,36 +28,68 @@ class NeuronExecutor(ExecutorBase):
self.parallel_config, self.parallel_config,
self.scheduler_config, self.scheduler_config,
self.device_config, self.device_config,
self.cache_config,
) )
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self, def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), ( and blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.") "Cache operations are not supported for Neuron backend.")
assert num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model( output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list) seq_group_metadata_list=seq_group_metadata_list)
return output return output
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError( return self.driver_worker.add_lora(lora_request)
"LoRA is not implemented for neuron backend.")
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError( return self.driver_worker.remove_lora(lora_id)
"LoRA is not implemented for neuron backend.")
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError( return self.driver_worker.list_loras()
"LoRA is not implemented for neuron backend.")
def check_health(self) -> None: def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as # NeuronExecutor will always be healthy as long as
# it's running. # it's running.
return return
class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list, )
return output
async def check_health_async(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
import asyncio import asyncio
import copy
import os import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.engine.ray_utils import RayWorkerWrapper, ray
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async, set_cuda_visible_devices) get_vllm_instance_id, make_async)
if ray is not None: if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
...@@ -32,23 +29,9 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) ...@@ -32,23 +29,9 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class RayGPUExecutor(ExecutorBase): class RayGPUExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, assert (not self.speculative_config
model_config: ModelConfig, ), "Speculative decoding not yet supported for RayGPU backend."
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
assert self.parallel_config.worker_use_ray assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
...@@ -61,13 +44,25 @@ class RayGPUExecutor(ExecutorBase): ...@@ -61,13 +44,25 @@ class RayGPUExecutor(ExecutorBase):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self._init_cache()
self.forward_dag = None self.forward_dag = None
if USE_RAY_COMPILED_DAG: if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag() self.forward_dag = self._compiled_ray_dag()
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})
return ray_remote_kwargs
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1: if self.parallel_config.tensor_parallel_size == 1:
...@@ -79,9 +74,13 @@ class RayGPUExecutor(ExecutorBase): ...@@ -79,9 +74,13 @@ class RayGPUExecutor(ExecutorBase):
# The driver dummy worker does not actually use any resources. # The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker. # It holds the resource for the driver worker.
self.driver_dummy_worker: RayWorkerVllm = None self.driver_dummy_worker: RayWorkerWrapper = None
# The remaining workers are the actual ray actors. # The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = [] self.workers: List[RayWorkerWrapper] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
...@@ -98,13 +97,22 @@ class RayGPUExecutor(ExecutorBase): ...@@ -98,13 +97,22 @@ class RayGPUExecutor(ExecutorBase):
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code) )(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote()) worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None: if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it # If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process. # as the resource holder for the driver process.
self.driver_dummy_worker = worker self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
trust_remote_code=self.model_config.trust_remote_code,
)
else: else:
# Else, added to the list of workers. # Else, added to the list of workers.
self.workers.append(worker) self.workers.append(worker)
...@@ -116,77 +124,59 @@ class RayGPUExecutor(ExecutorBase): ...@@ -116,77 +124,59 @@ class RayGPUExecutor(ExecutorBase):
"GPU node.") "GPU node.")
# Get the set of GPU IDs used on each node. # Get the set of GPU IDs used on each node.
driver_node_id, driver_gpu_ids = ray.get( worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
self.driver_dummy_worker.get_node_and_gpu_ids.remote()) use_dummy_driver=True)
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
node_workers = defaultdict(list) node_workers = defaultdict(list)
node_gpus = defaultdict(list) node_gpus = defaultdict(list)
node_workers[driver_node_id].append(0) for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
start=1):
node_workers[node_id].append(i) node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids) node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items(): for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids) node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver and workers. VLLM_INSTANCE_ID = get_vllm_instance_id()
set_cuda_visible_devices(node_gpus[driver_node_id])
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): # Set environment variables for the driver and workers.
worker.set_cuda_visible_devices.remote(node_gpus[node_id]) all_args_to_update_environment_variables = [({
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
os.getenv("VLLM_TRACE_FUNCTION", "0"),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port()) driver_ip, get_open_port())
# Lazy import the Worker to avoid importing torch.cuda/xformers def collect_arg_helper_func(**kwargs):
# before CUDA_VISIBLE_DEVICES is set in the Worker # avoid writing `{"name": value}` manually
from vllm.worker.worker import Worker return kwargs
model_config = copy.deepcopy(self.model_config) # Initialize the actual workers inside worker wrapper.
parallel_config = copy.deepcopy(self.parallel_config) init_worker_all_kwargs = []
scheduler_config = copy.deepcopy(self.scheduler_config) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
kv_cache_dtype = self.cache_config.cache_dtype
# Initialize the actual workers with the Worker class.
for rank, (worker, (node_id, _)) in enumerate(
zip(self.workers, worker_node_and_gpu_ids),
start=1,
):
local_rank = node_workers[node_id].index(rank) local_rank = node_workers[node_id].index(rank)
worker.init_worker.remote( init_worker_all_kwargs.append(
lambda rank=rank, local_rank=local_rank: Worker( collect_arg_helper_func(
model_config, model_config=self.model_config,
parallel_config, parallel_config=self.parallel_config,
scheduler_config, scheduler_config=self.scheduler_config,
device_config, device_config=self.device_config,
local_rank, cache_config=self.cache_config,
rank, load_config=self.load_config,
distributed_init_method, local_rank=local_rank,
lora_config=lora_config, rank=rank,
kv_cache_dtype=kv_cache_dtype, distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
)) ))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
# Initialize the driver worker with the Worker class.
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
driver_local_rank,
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=True,
)
self._run_workers("init_device") self._run_workers("init_device")
self._run_workers( self._run_workers(
...@@ -195,35 +185,18 @@ class RayGPUExecutor(ExecutorBase): ...@@ -195,35 +185,18 @@ class RayGPUExecutor(ExecutorBase):
max_parallel_loading_workers, max_parallel_loading_workers,
) )
def _init_cache(self) -> None: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the memory usage and initializes the KV cache. """Determine the number of available KV blocks.
The engine will first conduct a profiling of the existing memory usage. This invokes `determine_num_available_blocks` on each worker and takes
Then, it calculate the maximum possible number of GPU and CPU blocks the min of the results, guaranteeing that the selected cache sizes are
that can be allocated with the remaining free memory. compatible with all workers.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers, Returns:
we take the minimum number of blocks across all workers - Tuple[num_gpu_blocks, num_cpu_blocks]
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
""" """
# Get the maximum number of blocks that can be allocated on GPU and CPU. # Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers( num_blocks = self._run_workers("determine_num_available_blocks", )
"profile_num_available_blocks",
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)
# Since we use a shared centralized controller, we take the minimum # Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory # number of blocks across all workers to make sure all the memory
...@@ -231,32 +204,32 @@ class RayGPUExecutor(ExecutorBase): ...@@ -231,32 +204,32 @@ class RayGPUExecutor(ExecutorBase):
num_gpu_blocks = min(b[0] for b in num_blocks) num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks)
if self.cache_config.forced_num_gpu_blocks is not None: return num_gpu_blocks, num_cpu_blocks
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with " def initialize_cache(self, num_gpu_blocks: int,
f"{forced_num_gpu_blocks=}") num_cpu_blocks: int) -> None:
num_gpu_blocks = forced_num_gpu_blocks """Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info(f"# GPU blocks: {num_gpu_blocks}, " logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}") f"# CPU blocks: {num_cpu_blocks}")
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache. self._run_workers("initialize_cache",
self._run_workers("init_cache_engine", cache_config=self.cache_config) num_gpu_blocks=num_gpu_blocks,
# Warm up the model. This includes capturing the model into CUDA graph num_cpu_blocks=num_cpu_blocks)
# if enforce_eager is False.
self._run_workers("warm_up_model")
def execute_model(self, def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int = 0) -> SamplerOutput:
all_outputs = self._run_workers( all_outputs = self._run_workers(
"execute_model", "execute_model",
driver_kwargs={ driver_kwargs={
...@@ -285,45 +258,69 @@ class RayGPUExecutor(ExecutorBase): ...@@ -285,45 +258,69 @@ class RayGPUExecutor(ExecutorBase):
lora_id=lora_id, lora_id=lora_id,
) )
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self._run_workers("list_loras") return self._run_workers("list_loras")
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[List[Any]] = None, driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False, use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers. Can be used in the following
ways:
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if max_concurrent_workers: if max_concurrent_workers:
raise NotImplementedError( raise NotImplementedError(
"max_concurrent_workers is not supported yet.") "max_concurrent_workers is not supported yet.")
if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
if use_ray_compiled_dag: if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single # Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it. # input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1) output_channels = self.forward_dag.execute(1)
else: else:
# Start the ray workers first. # Start the ray workers first.
ray_worker_outputs = [ ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs) worker.execute_method.remote(method, *worker_args,
for worker in self.workers **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
] ]
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Start the driver worker after all the ray workers. # Start the driver worker after all the ray workers.
driver_worker_output = getattr(self.driver_worker, if not use_dummy_driver:
method)(*driver_args, **driver_kwargs) driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
if use_ray_compiled_dag: if use_ray_compiled_dag:
...@@ -381,11 +378,15 @@ class RayGPUExecutor(ExecutorBase): ...@@ -381,11 +378,15 @@ class RayGPUExecutor(ExecutorBase):
class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method)
async def _run_workers_async( async def _run_workers_async(
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[List[Any]] = None, driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
...@@ -397,9 +398,8 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): ...@@ -397,9 +398,8 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
if driver_kwargs is None: if driver_kwargs is None:
driver_kwargs = kwargs driver_kwargs = kwargs
# Run the driver worker asynchronously. coros.append(
driver_executor = make_async(getattr(self.driver_worker, method)) self.driver_executor(method, *driver_args, **driver_kwargs))
coros.append(driver_executor(*driver_args, **driver_kwargs))
# Run the ray workers asynchronously. # Run the ray workers asynchronously.
for worker in self.workers: for worker in self.workers:
...@@ -427,7 +427,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): ...@@ -427,7 +427,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
output = all_outputs[0] output = all_outputs[0]
return output return output
async def check_health_async(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * num_gpu_blocks
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
# Adapted from # Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM.""" """Logging configuration for vLLM."""
import datetime
import logging import logging
import os import os
import sys import sys
from functools import partial
from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
...@@ -26,7 +29,7 @@ class NewLineFormatter(logging.Formatter): ...@@ -26,7 +29,7 @@ class NewLineFormatter(logging.Formatter):
_root_logger = logging.getLogger("vllm") _root_logger = logging.getLogger("vllm")
_default_handler = None _default_handler: Optional[logging.Handler] = None
def _setup_logger(): def _setup_logger():
...@@ -55,7 +58,76 @@ def init_logger(name: str): ...@@ -55,7 +58,76 @@ def init_logger(name: str):
# Use the same settings as above for root logger # Use the same settings as above for root logger
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
if VLLM_CONFIGURE_LOGGING: if VLLM_CONFIGURE_LOGGING:
if _default_handler is None:
raise ValueError(
"_default_handler is not set up. This should never happen!"
" Please open an issue on Github.")
logger.addHandler(_default_handler) logger.addHandler(_default_handler)
logger.propagate = False logger.propagate = False
return logger return logger
logger = init_logger(__name__)
def _trace_calls(log_path, root_dir, frame, event, arg=None):
if event in ['call', 'return']:
# Extract the filename, line number, function name, and the code object
filename = frame.f_code.co_filename
lineno = frame.f_lineno
func_name = frame.f_code.co_name
if not filename.startswith(root_dir):
# only log the functions in the vllm root_dir
return
# Log every function call or return
try:
last_frame = frame.f_back
if last_frame is not None:
last_filename = last_frame.f_code.co_filename
last_lineno = last_frame.f_lineno
last_func_name = last_frame.f_code.co_name
else:
# initial frame
last_filename = ""
last_lineno = 0
last_func_name = ""
with open(log_path, 'a') as f:
if event == 'call':
f.write(f"{datetime.datetime.now()} Call to"
f" {func_name} in {filename}:{lineno}"
f" from {last_func_name} in {last_filename}:"
f"{last_lineno}\n")
else:
f.write(f"{datetime.datetime.now()} Return from"
f" {func_name} in {filename}:{lineno}"
f" to {last_func_name} in {last_filename}:"
f"{last_lineno}\n")
except NameError:
# modules are deleted during shutdown
pass
return partial(_trace_calls, log_path, root_dir)
def enable_trace_function_call(log_file_path: str,
root_dir: Optional[str] = None):
"""
Enable tracing of every function call in code under `root_dir`.
This is useful for debugging hangs or crashes.
`log_file_path` is the path to the log file.
`root_dir` is the root directory of the code to trace. If None, it is the
vllm root directory.
Note that this call is thread-level, any threads calling this function
will have the trace enabled. Other threads will not be affected.
"""
logger.warning(
"VLLM_TRACE_FUNCTION is enabled. It will record every"
" function executed by Python. This will slow down the code. It "
"is suggested to be used for debugging hang or crashes only.")
logger.info(f"Trace frame log is saved to {log_file_path}")
if root_dir is None:
# by default, this is the vllm root directory
root_dir = os.path.dirname(os.path.dirname(__file__))
sys.settrace(partial(_trace_calls, log_file_path, root_dir))
...@@ -10,6 +10,12 @@ import torch.nn.functional as F ...@@ -10,6 +10,12 @@ import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -18,18 +24,27 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -18,18 +24,27 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import (
split_tensor_along_last_dim)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
# unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
# GPTQ/AWQ/SqueezeLLM
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# marlin
elif hasattr(base_layer, "B"):
return base_layer.B.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _apply_lora( def _apply_lora(
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: torch.Tensor, lora_a_stacked: torch.Tensor,
...@@ -268,12 +283,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -268,12 +283,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices, x + indices,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))
...@@ -302,6 +318,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -302,6 +318,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size = self.base_layer.input_size
self.output_size = self.base_layer.output_size_per_partition
self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -312,17 +331,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -312,17 +331,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.lora_b_stacked = torch.zeros( self.lora_b_stacked = torch.zeros(
max_loras, max_loras,
1, 1,
self.base_layer.weight.shape[0], self.output_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
...@@ -368,7 +387,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -368,7 +387,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def apply_weights(self, x: torch.Tensor, def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias) self.base_layer, x, bias)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -402,10 +421,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -402,10 +421,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if self.base_layer.skip_bias_add else None) if self.base_layer.skip_bias_add else None)
return output, output_bias return output, output_bias
@property
def linear_weights(self):
return self.base_layer.linear_weights
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: List,
...@@ -446,18 +461,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -446,18 +461,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) for _ in range(n_slices)) ) for _ in range(n_slices))
self.lora_b_stacked = tuple( self.lora_b_stacked = tuple(
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
self.base_layer.weight.shape[0] // 2, self.output_size // 2,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) for _ in range(n_slices)) ) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
...@@ -505,7 +520,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -505,7 +520,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def apply_weights(self, x: torch.Tensor, def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias) self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -623,25 +638,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -623,25 +638,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
) )
self.lora_b_stacked = ( self.lora_b_stacked = (
...@@ -651,7 +666,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -651,7 +666,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.q_proj_shard_size, self.q_proj_shard_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
...@@ -659,7 +674,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -659,7 +674,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size, self.kv_proj_shard_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
...@@ -667,7 +682,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -667,7 +682,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size, self.kv_proj_shard_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
) )
...@@ -746,7 +761,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -746,7 +761,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def apply_weights(self, x: torch.Tensor, def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias) self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -770,6 +785,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -770,6 +785,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None: def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -781,20 +799,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -781,20 +799,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.lora_b_stacked = torch.zeros( self.lora_b_stacked = torch.zeros(
( (
max_loras, max_loras,
1, 1,
self.base_layer.weight.shape[0], self.output_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None self.indices_len: Optional[List[int]] = None
...@@ -813,7 +831,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -813,7 +831,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.reset_lora(index) self.reset_lora(index)
if self.base_layer.tp_size > 1: if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.base_layer.weight.shape[1] shard_size = self.input_size
start_idx = tensor_model_parallel_rank * shard_size start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :] lora_a = lora_a[start_idx:end_idx, :]
...@@ -838,7 +856,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -838,7 +856,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def apply_weights(self, x: torch.Tensor) -> torch.Tensor: def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x) self.base_layer, x)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -888,7 +906,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -888,7 +906,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@property @property
def weight(self): def weight(self):
return self.base_layer.weight
return self.base_layer.weight if hasattr(
self.base_layer, "weight") else self.base_layer.qweight
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
...@@ -939,9 +959,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -939,9 +959,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig] = None, model_config: Optional[PretrainedConfig] = None,
) -> None: ) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if 32000 < self.base_layer.vocab_size > 33024: if 32000 < self.base_layer.vocab_size > 128512:
raise ValueError("When using LoRA, vocab size must be " raise ValueError("When using LoRA, vocab size must be "
"32000 >= vocab_size <= 33024") "32000 >= vocab_size <= 128512")
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
......
...@@ -33,7 +33,7 @@ class LoRALayerWeights: ...@@ -33,7 +33,7 @@ class LoRALayerWeights:
def optimize(self) -> "LoRALayerWeights": def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b.""" """Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1: if self.scaling == 1:
return return self
self.lora_b *= self.scaling self.lora_b *= self.scaling
self.scaling = 1 self.scaling = 1
return self return self
......
...@@ -191,6 +191,7 @@ class LoRAModel: ...@@ -191,6 +191,7 @@ class LoRAModel:
def from_local_checkpoint( def from_local_checkpoint(
cls, cls,
lora_dir: str, lora_dir: str,
expected_lora_modules: List[str],
lora_model_id: Optional[int] = None, lora_model_id: Optional[int] = None,
device: str = "cuda", device: str = "cuda",
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
...@@ -206,6 +207,22 @@ class LoRAModel: ...@@ -206,6 +207,22 @@ class LoRAModel:
lora_dir, "new_embeddings.safetensors") lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir, new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin") "new_embeddings.bin")
with open(lora_config_path) as f:
config = json.load(f)
target_modules = config["target_modules"]
unexpected_modules = []
for module in target_modules:
# Compatible with more modules, such as:layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
if os.path.isfile(lora_tensor_path): if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path) tensors = safetensors.torch.load_file(lora_tensor_path)
elif os.path.isfile(lora_bin_file_path): elif os.path.isfile(lora_bin_file_path):
...@@ -220,8 +237,6 @@ class LoRAModel: ...@@ -220,8 +237,6 @@ class LoRAModel:
elif os.path.isfile(new_embeddings_bin_file_path): elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path) embeddings = torch.load(new_embeddings_bin_file_path)
with open(lora_config_path) as f:
config = json.load(f)
rank = config["r"] rank = config["r"]
lora_alpha = config["lora_alpha"] lora_alpha = config["lora_alpha"]
return cls.from_lora_tensors( return cls.from_lora_tensors(
......
...@@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
self._lora_manager: LoRAModelManager = lora_manager self._lora_manager: LoRAModelManager = lora_manager
return lora_manager.model return lora_manager.model
def set_active_loras(self, lora_requests: List[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests) self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping) self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras() loras_that_exist = self.list_loras()
loras_map = { loras_map = {
lora_request.lora_int_id: lora_request lora_request.lora_int_id: lora_request
...@@ -136,8 +136,19 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -136,8 +136,19 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try: try:
model = self._lora_manager.model
supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping
expected_lora_modules = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(
packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
lora = self._lora_model_cls.from_local_checkpoint( lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path, lora_request.lora_local_path,
expected_lora_modules,
lora_model_id=lora_request.lora_int_id, lora_model_id=lora_request.lora_int_id,
device="cpu", device="cpu",
dtype=self.lora_config.lora_dtype, dtype=self.lora_config.lora_dtype,
......
from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
if guided_decoding_backend == 'outlines':
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union
from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
RegexParser, StringParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if request.guided_json:
schema = _normalize_json_schema_object(request.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif request.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in request.guided_choice])
elif request.guided_regex:
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
return build_vllm_token_enforcer_tokenizer_data(tokenizer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment