Commit 51679bbd authored by zhuwenwen

resolve merge conflicts

parents 4095d0db 1af090b5
...@@ -33,11 +33,15 @@ async def generate(request: Request) -> Response:
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = engine.generate(prompt,
sampling_params,
request_id,
prefix_pos=prefix_pos)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
...@@ -74,12 +78,18 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
...
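The hunks above add an experimental prefix_pos field to the demo /generate endpoint (forwarded to engine.generate so the prefix's KV cache can be reused) and a --root-path flag for running behind a path-based routing proxy. A minimal client sketch, assuming the server runs on localhost:8000 and that prefix_pos is the shared prefix length in tokens; the prompt and sampling values are illustrative:

# Hedged sketch: any JSON keys other than "prompt", "prefix_pos" and "stream"
# are forwarded to SamplingParams by the handler shown above.
import json

import requests

shared_prefix = "You are a helpful assistant. Answer concisely.\n"
payload = {
    "prompt": shared_prefix + "Q: What is vLLM?\nA:",
    "prefix_pos": 16,  # assumed: number of prompt tokens covered by the shared prefix
    "stream": False,
    "max_tokens": 64,
    "temperature": 0.0,
}
resp = requests.post("http://localhost:8000/generate", json=payload, timeout=60)
print(json.dumps(resp.json(), indent=2))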
...@@ -3,6 +3,7 @@ from typing import List, Optional, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
...@@ -63,6 +64,7 @@ class LLM:
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
""" """
def __init__( def __init__(
...@@ -81,6 +83,7 @@ class LLM: ...@@ -81,6 +83,7 @@ class LLM:
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: int = 8192, max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
...@@ -100,6 +103,7 @@ class LLM:
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
...@@ -120,7 +124,9 @@ class LLM:
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
...@@ -134,7 +140,13 @@ class LLM:
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns:
A list of `RequestOutput` objects containing the generated
...@@ -159,9 +171,14 @@ class LLM:
prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt, sampling_params, token_ids)
self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos_i)
return self._run_engine(use_tqdm)
def _add_request(
...@@ -169,10 +186,16 @@ class LLM:
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, prompt, sampling_params,
prompt_token_ids)
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
...
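The vllm/entrypoints/llm.py changes above thread two new arguments through LLM.generate, prefix_pos (experimental prefix KV-cache reuse) and lora_request, plus a disable_custom_all_reduce constructor flag. A minimal offline sketch, assuming prefix_pos is the token length of the shared prefix; the model name and prompts are placeholders, and lora_request is omitted because constructing a LoRARequest needs an adapter checkpoint on disk:

# Hedged sketch of the new generate() arguments added in this commit.
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model_id = "facebook/opt-125m"  # assumption: any small HF model works for the demo
prefix = "You are a terse assistant.\n"
prompts = [prefix + "Explain KV caching.", prefix + "Explain paged attention."]

prefix_len = len(AutoTokenizer.from_pretrained(model_id)(prefix).input_ids)
llm = LLM(model=model_id, disable_custom_all_reduce=False)
outputs = llm.generate(prompts,
                       SamplingParams(temperature=0.0, max_tokens=32),
                       prefix_pos=[prefix_len] * len(prompts))
for out in outputs:
    print(out.outputs[0].text)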
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse
import asyncio
import codecs
import json
import time
from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
from contextlib import asynccontextmanager
import os
import importlib
import inspect
from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
import fastapi
import uvicorn
from http import HTTPStatus
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
...@@ -21,26 +19,33 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import add_global_metrics_labels
from vllm.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()
engine = None
response_role = None
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
if not engine_args.disable_log_stats:
asyncio.create_task(_force_log())
yield
app = fastapi.FastAPI(lifespan=lifespan)
def parse_args():
...@@ -63,6 +68,13 @@ def parse_args():
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument(
"--api-key",
type=str,
default=None,
help=
"If provided, the server will require this key to be presented in the header."
)
parser.add_argument("--served-model-name", parser.add_argument("--served-model-name",
type=str, type=str,
default=None, default=None,
...@@ -88,6 +100,22 @@ def parse_args(): ...@@ -88,6 +100,22 @@ def parse_args():
type=str, type=str,
default=None, default=None,
help="The file path to the SSL cert file") help="The file path to the SSL cert file")
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). "
)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
...@@ -97,72 +125,10 @@ app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics
app.add_route("/metrics", metrics) # Exposes HTTP metrics
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)
def load_chat_template(args, tokenizer):
if args.chat_template is not None:
try:
with open(args.chat_template, "r") as f:
chat_template = f.read()
except OSError:
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
chat_template = codecs.decode(args.chat_template, "unicode_escape")
tokenizer.chat_template = chat_template
logger.info(
f"Using supplied chat template:\n{tokenizer.chat_template}")
elif tokenizer.chat_template is not None:
logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
else:
logger.warning("No chat template provided. Chat API will not work.")
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
input_ids = prompt_ids if prompt_ids is not None else tokenizer(
prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = max_model_len - token_num
if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return input_ids, None
@app.get("/health") @app.get("/health")
...@@ -173,544 +139,37 @@ async def health() -> Response: ...@@ -173,544 +139,37 @@ async def health() -> Response:
@app.get("/v1/models") @app.get("/v1/models")
async def show_available_models(): async def show_available_models():
"""Show available models. Right now we only have one model.""" models = await openai_serving_chat.show_available_models()
model_cards = [ return JSONResponse(content=models.model_dump())
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
def create_logprobs(
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
else:
token_logprob = None
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest, async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request): raw_request: Request):
"""Completion API similar to OpenAI's API. generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
See https://platform.openai.com/docs/api-reference/chat/create if isinstance(generator, ErrorResponse):
for the API specification. This API mimics the OpenAI ChatCompletion API. return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
try:
prompt = tokenizer.apply_chat_template(
conversation=request.messages,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
except Exception as e:
logger.error(f"Error in applying chat template from request: {str(e)}")
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic())
chunk_object_type = "chat.completion.chunk"
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
def get_role() -> str:
if request.add_generation_prompt:
return response_role
else:
return request.messages[-1]["role"]
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# Send first response for each request.n (index) with the role
role = get_role()
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=last_msg_content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
# Send response for each token for each request.n (index)
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
if output.finish_reason is None:
# Send token-by-token response for each request.n
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=[], finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.json(exclude_unset=True,
exclude_none=True,
ensure_ascii=False)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def completion_full_generator():
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
role = get_role()
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return await completion_full_generator()
return JSONResponse(content=generator.model_dump())
@app.post("/v1/completions") @app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request): async def create_completion(request: CompletionRequest, raw_request: Request):
"""Completion API similar to OpenAI's API. generator = await openai_serving_completion.create_completion(
request, raw_request)
See https://platform.openai.com/docs/api-reference/completions/create if isinstance(generator, ErrorResponse):
for the API specification. This API mimics the OpenAI Completion API. return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
NOTE: Currently we do not support the following features:
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_without_generation = request.echo and request.max_tokens == 0
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handles multiple prompt case in list[list[int]]
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
_, error_check_ret = await check_length(request, prompt_ids=prompt)
else:
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.monotonic())
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens
if not echo_without_generation else 1,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
prompt_logprobs=request.logprobs if request.echo else None,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
if use_token_ids:
result_generator = engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
usage: Optional[UsageInfo] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
if usage is not None:
response.usage = usage
response_json = response.json(exclude_unset=True, ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
has_echoed = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
token_ids = output.token_ids[previous_num_tokens[i]:]
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i]:]
else:
top_logprobs = None
offsets = len(previous_texts[i])
if request.echo and not has_echoed[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else: # only just return the prompt
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
if request.logprobs is not None:
logprobs = create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=offsets,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
usage=final_usage,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.logprobs is not None:
if not echo_without_generation:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.echo:
token_ids = prompt_token_ids + token_ids
top_logprobs = prompt_logprobs + top_logprobs
else:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
logprobs = create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return response
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
...@@ -724,6 +183,29 @@ if __name__ == "__main__":
allow_headers=args.allowed_headers,
)
if token := os.environ.get("VLLM_API_KEY") or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
if not request.url.path.startswith("/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(
f"Invalid middleware {middleware}. Must be a function or a class."
)
logger.info(f"args: {args}")
if args.served_model_name is not None:
...@@ -731,23 +213,17 @@ if __name__ == "__main__":
else:
served_model = args.model
response_role = args.response_role
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
load_chat_template(args, tokenizer)
openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(engine, served_model)
# Register labels for metrics
add_global_metrics_labels(model_name=engine_args.model)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
...
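The rewritten OpenAI entrypoint above delegates request handling to OpenAIServingChat and OpenAIServingCompletion and gains three deployment flags: --api-key (requests to /v1 routes must then send Authorization: Bearer <key>), --root-path, and repeatable --middleware options that name an import path. Per the inspect checks above, a class is registered with app.add_middleware() and a coroutine function with @app.middleware("http"). A hedged sketch of such a coroutine middleware; the module, function, and header names are made up:

# Saved e.g. as my_middlewares.py and passed as
#   --middleware my_middlewares.add_request_id
# (module/function names are assumptions for illustration).
import uuid

from fastapi import Request


async def add_request_id(request: Request, call_next):
    # Tag every response that passes through the app with a request ID header.
    response = await call_next(request)
    response.headers["X-Request-Id"] = str(uuid.uuid4())
    return response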
...@@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
class ErrorResponse(BaseModel):
...@@ -13,7 +14,7 @@ class ErrorResponse(BaseModel):
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None
code: int
class ModelPermission(BaseModel):
...@@ -77,6 +78,30 @@ class ChatCompletionRequest(BaseModel):
echo: Optional[bool] = False
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
def to_sampling_params(self) -> SamplingParams:
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
use_beam_search=self.use_beam_search,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
)
class CompletionRequest(BaseModel):
...@@ -106,6 +131,34 @@ class CompletionRequest(BaseModel):
spaces_between_special_tokens: Optional[bool] = True
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
)
class LogProbs(BaseModel):
...@@ -144,7 +197,7 @@ class CompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo]
usage: Optional[UsageInfo] = Field(default=None)
class ChatMessage(BaseModel):
...@@ -184,5 +237,4 @@ class ChatCompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(
default=None, description="data about request and response")
usage: Optional[UsageInfo] = Field(default=None)
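With the protocol changes above, each request model now owns its mapping to engine SamplingParams via to_sampling_params(), including the new include_stop_str_in_output and length_penalty fields, and the streaming responses get an explicit usage default. A rough sketch of that mapping, assuming the request models keep their required fields (model, messages); the values are illustrative:

# Hedged sketch: build a chat request and inspect the derived SamplingParams.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model="my-model",
    messages=[{"role": "user", "content": "Say hi."}],
    temperature=0.2,
    max_tokens=16,
    include_stop_str_in_output=True,  # new field introduced in this commit
)
print(req.to_sampling_params())  # SamplingParams(n=1, temperature=0.2, max_tokens=16, ...)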
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Union
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing
logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
response_role: str,
chat_template=None):
super().__init__(engine=engine, served_model=served_model)
self.response_role = response_role
self._load_chat_template(chat_template)
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return self.create_error_response(
"logit_bias is not currently supported")
try:
prompt = self.tokenizer.apply_chat_template(
conversation=request.messages,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
except Exception as e:
logger.error(
f"Error in applying chat template from request: {str(e)}")
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
try:
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
except ValueError as e:
return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt, sampling_params,
request_id, token_ids)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
else:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1].role
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
model_name = request.model
created_time = int(time.monotonic())
chunk_object_type = "chat.completion.chunk"
# Send first response for each request.n (index) with the role
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=last_msg_content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response for each token for each request.n (index)
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = request.model
created_time = int(time.monotonic())
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None
choices = []
role = self.get_chat_request_role(request)
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
def _load_chat_template(self, chat_template):
if chat_template is not None:
try:
with open(chat_template, "r") as f:
self.tokenizer.chat_template = f.read()
except OSError:
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
self.tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info(
f"Using supplied chat template:\n{self.tokenizer.chat_template}"
)
elif self.tokenizer.chat_template is not None:
logger.info(
f"Using default chat template:\n{self.tokenizer.chat_template}"
)
else:
logger.warning(
"No chat template provided. Chat API will not work.")
import asyncio
import time
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from .protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs,
UsageInfo,
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing
logger = init_logger(__name__)
TypeTokenIDs = list[int]
TypeTopLogProbs = List[Optional[dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
async def completion_stream_generator(
request: CompletionRequest,
raw_request: Request,
on_abort,
result_generator: AsyncIterator[tuple[int, RequestOutput]],
create_logprobs_fn: TypeCreateLogProbsFn,
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
) -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
async for prompt_idx, res in result_generator:
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await on_abort(f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
for output in res.outputs:
i = output.index + prompt_idx * request.n
# TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = res.prompt_token_ids + output.token_ids
top_logprobs = res.prompt_logprobs + (output.logprobs or [])
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs is not None:
assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
logprobs = create_logprobs_fn(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
]).model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None: # return final usage
logprobs = LogProbs() if request.logprobs is not None else None
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
],
usage=final_usage,
).model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
def parse_prompt_format(prompt) -> tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError(
"prompt must be a string, array of strings, array of tokens, or array of token arrays"
)
return prompt_is_tokens, prompts
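parse_prompt_format above normalizes the four prompt shapes the OpenAI completions API accepts (string, list of strings, list of token IDs, list of token-ID lists) into a (prompt_is_tokens, prompts) pair. A quick illustration, assuming the helper stays importable from this module:

from vllm.entrypoints.openai.serving_completion import parse_prompt_format

print(parse_prompt_format("hello"))              # (False, ['hello'])
print(parse_prompt_format(["hello", "world"]))   # (False, ['hello', 'world'])
print(parse_prompt_format([1, 2, 3]))            # (True, [[1, 2, 3]])
print(parse_prompt_format([[1, 2], [3, 4]]))     # (True, [[1, 2], [3, 4]])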
def request_output_to_completion_response(
final_res_batch: list[RequestOutput],
request: CompletionRequest,
create_logprobs_fn: TypeCreateLogProbsFn,
request_id: str,
created_time: int,
model_name: str,
) -> CompletionResponse:
choices = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = prompt_logprobs + output.logprobs
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
top_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
def merge_async_iterators(*iterators):
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue = asyncio.Queue()
finished = [False] * len(iterators)
async def producer(i, iterator):
async for item in iterator:
await queue.put((i, item))
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
yield item
await asyncio.gather(*_tasks)
return consumer()
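merge_async_iterators above is what lets one completion request fan out to several per-prompt generators while still streaming results as they arrive, each tagged with the index of the prompt it belongs to. A small usage sketch, assuming the helper remains module-level here; the interleaving of the printed pairs depends on task scheduling:

import asyncio

from vllm.entrypoints.openai.serving_completion import merge_async_iterators


async def numbers(start: int):
    # Toy async generator standing in for a per-prompt RequestOutput stream.
    for n in range(start, start + 3):
        await asyncio.sleep(0)
        yield n


async def main():
    async for idx, item in merge_async_iterators(numbers(0), numbers(10)):
        print(idx, item)  # e.g. "0 0", "1 10", "0 1", ...


asyncio.run(main())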
class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, served_model: str):
super().__init__(engine=engine, served_model=served_model)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
return self.create_error_response(
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators = []
try:
sampling_params = request.to_sampling_params()
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt)
generators.append(
self.engine.generate(None,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids))
except ValueError as e:
return self.create_error_response(str(e))
result_generator: AsyncIterator[tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
return completion_stream_generator(request,
raw_request,
self.engine.abort,
result_generator,
self._create_logprobs,
request_id,
created_time,
model_name,
num_prompts=len(prompts))
# Non-streaming response
final_res_batch: RequestOutput = [None] * len(prompts)
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_completion_response(
final_res_batch, request, self._create_logprobs, request_id,
created_time, model_name)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
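With the refactor above, a single /v1/completions request can now carry several prompts (each scheduled as its own engine request and merged back together), and streaming is still skipped when n != best_of or beam search is used. A hedged client sketch against a locally running server; host, model name, and API key are assumptions:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    headers={"Authorization": "Bearer s3cr3t"},  # only needed if --api-key is set
    json={
        "model": "my-model",
        "prompt": ["The capital of France is", "The capital of Japan is"],
        "max_tokens": 8,
        "temperature": 0.0,
    },
    timeout=60,
)
for choice in resp.json()["choices"]:
    print(choice["index"], choice["text"])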
import asyncio
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest,
ErrorResponse, LogProbs,
ModelCard, ModelList,
ModelPermission)
logger = init_logger(__name__)
class OpenAIServing:
def __init__(self, engine: AsyncLLMEngine, served_model: str):
self.engine = engine
self.served_model = served_model
self.max_model_len = 0
self.tokenizer = None
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
        if event_loop is not None and event_loop.is_running():
            # If the current process is instantiated by Ray Serve,
            # there is already a running event loop.
event_loop.create_task(self._post_init())
else: # When using single vLLM without engine_use_ray
asyncio.run(self._post_init())
async def _post_init(self):
engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=self.served_model,
root=self.served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
else:
token_logprob = None
token = self.tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append({
self.tokenizer.convert_ids_to_tokens(i): p
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs
def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)
async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model == self.served_model:
return
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
else:
return input_ids
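# A minimal worked example of the token-budget check above (the numbers are
# assumptions for illustration, not values from this file): with a 4096-token
# context and a 1000-token prompt, max_tokens defaults to the 3096-token
# remainder, and requesting more than that raises.
def _example_token_budget(max_model_len: int = 4096,
                          prompt_tokens: int = 1000,
                          max_tokens: Optional[int] = None) -> int:
    if max_tokens is None:
        max_tokens = max_model_len - prompt_tokens
    if prompt_tokens + max_tokens > max_model_len:
        raise ValueError("requested tokens exceed the model's context length")
    return max_tokens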
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather,
)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear,
QKVParallelLinear,
MergedColumnParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim
if TYPE_CHECKING:
pass
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
indices: torch.Tensor,
output: torch.Tensor,
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
lora_b_stacked: (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
return output.view_as(org_output)
def _apply_lora_packed_nslice(
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
This method is used for layers that are composed of multiple sublayers
(slices) packed together.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size)
        output: (batch_size, sum(output_slices))
        output_slices: tuple of slice sizes (one entry per slice)
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx in range(len(output_slices)):
add_lora_slice(output, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
return output.view_as(org_output)
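# Reference sketch of what the stacked-LoRA application above computes,
# written in plain PyTorch for illustration only; the real path goes through
# the punica CUDA kernels, and the shapes here mirror the docstring of
# _apply_lora (names and sizes are assumptions).
def _reference_apply_lora(x: torch.Tensor, lora_a_stacked: torch.Tensor,
                          lora_b_stacked: torch.Tensor,
                          indices: torch.Tensor,
                          output: torch.Tensor) -> torch.Tensor:
    # x: (batch, hidden), lora_a_stacked: (num_loras, rank, hidden),
    # lora_b_stacked: (num_loras, out_dim, rank), indices: (batch,)
    for row, idx in enumerate(indices.tolist()):
        if idx < 0:  # -1 means "no LoRA for this row"
            continue
        # y += (x A^T) B^T for the LoRA occupying slot `idx`, scale 1.0
        output[row] += x[row] @ lora_a_stacked[idx].t() @ lora_b_stacked[idx].t()
    return output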
@dataclass
class LoRAMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
class BaseLayerWithLoRA(nn.Module):
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
model_config: PretrainedConfig) -> None:
"""Initializes lora matrices."""
...
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
"""Overwrites lora tensors at index."""
...
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
"""Sets the mapping indices."""
...
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
lora_vocab_start_idx = self.base_layer.org_vocab_size
weights_idx = None
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
# We can start adding lora weights
weights_idx = max(
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
self.embeddings_slice = (self.base_layer.vocab_start_index -
self.base_layer.org_vocab_size +
weights_idx,
self.base_layer.vocab_end_index -
self.base_layer.org_vocab_size)
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
self.embeddings_weights.fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size +
lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.embedding_dim,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.embeddings_indices = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1]].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] *
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.embeddings_indices = embeddings_indices
self.indices_len = indices_len
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
full_output_org = full_output
if full_output.ndim == 3:
full_output = full_output.view(
full_output.shape[0] * full_output.shape[1], -1)
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1], -1)
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org)
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_a_stacked = torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0],
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[1]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output
def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = (self.base_layer.bias
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
output_parallel = self.apply_weights(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = (self.base_layer.bias
if self.base_layer.skip_bias_add else None)
return output, output_bias
@property
def linear_weights(self):
return self.base_layer.linear_weights
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
Both slices must have the same size.
"""
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
and self.base_layer.output_sizes[0]
== self.base_layer.output_sizes[1]):
raise ValueError(
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
) for _ in range(n_slices))
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0] // 2,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[0][:,
start_idx:end_idx], lora_b[1][:,
start_idx:end_idx]
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
(self.output_dim, self.output_dim),
)
return output
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
This means we have 3 LoRAs, each applied to one slice of the layer.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
)
self.lora_b_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
)
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
self.lora_b_stacked[0][index] = 0
self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[1][index] = 0
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
else:
if lora_b[0] is not None:
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_b[1] is not None:
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if lora_b[2] is not None:
self.lora_b_stacked[2][
index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
lora_b[2].T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
if lora_a[2] is not None:
self.lora_a_stacked[2][
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
self.output_slices,
)
return output
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.weight.shape[0],
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.base_layer.weight.shape[1]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# Set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (output_ + self.base_layer.bias
if self.base_layer.bias is not None else output_)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
@property
def weight(self):
return self.base_layer.weight
class SamplerWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: Sampler,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
@property
def vocab_size(self):
return self.base_layer.vocab_size
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@property
def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 33024):
            raise ValueError(
                "When using LoRA, vocab size must satisfy "
                "32000 <= vocab_size <= 33024")
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
# Pad for kernel compatibility
math.ceil(self.base_layer.vocab_size /
lora_config.lora_vocab_padding_size) *
lora_config.lora_vocab_padding_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.embeddings_tensors = torch.full(
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
fill_value=float("-inf"),
dtype=self.dtype,
device=self.device,
)
self.indices = None
self.indices_padded = None
self.indices_len = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1], ] = embeddings_tensor
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = sampler_indices
self.indices_padded = sampler_indices_padded
self.indices_len = indices_len
def _get_logits(
self,
hidden_states: torch.Tensor,
embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
if logits is None:
return None
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],
hidden_states.shape[0],
dtype=self.embeddings_tensors.dtype,
device=self.embeddings_tensors.device,
)
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
lora_logits[-1] = float("-inf")
lora_logits = lora_logits.mT
lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0,
self.indices_padded[:self.indices_len[2]]).nan_to_num_(
nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
_apply_lora(
hidden_states,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[1]],
logits,
)
# Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size]
return logits
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
def from_layer(
layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
supported_layer_types = {
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLora,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_sampler(
layer: Sampler,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
from typing import List, Optional
import torch
from vllm.utils import in_wsl
class LoRALayerWeights:
"""LoRA weights for a layer composed of two low rank matrixes."""
def __init__(
self,
module_name: str,
rank: int,
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None,
) -> None:
self.module_name = module_name
self.rank = rank
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
self.embeddings_tensor = embeddings_tensor
if scaling is None:
self.scaling = self.lora_alpha / self.rank
else:
self.scaling = scaling
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1:
            return self
self.lora_b *= self.scaling
self.scaling = 1
return self
@property
def input_dim(self) -> int:
return self.lora_a.shape[0]
@property
def output_dim(self) -> int:
return self.lora_b.shape[1]
@property
def is_packed(self) -> bool:
return False
@property
def extra_vocab_size(self) -> int:
return self.embeddings_tensor.shape[
0] if self.embeddings_tensor is not None else 0
@classmethod
def create_dummy_lora_weights(
cls,
module_name: str,
input_dim: int,
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and not in_wsl()
lora_a = torch.zeros([input_dim, rank],
dtype=dtype,
device=device,
pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
embeddings_tensor = torch.rand(
10,
embeddings_tensor_dim,
dtype=dtype,
device=device,
pin_memory=pin_memory) if embeddings_tensor_dim else None
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
embeddings_tensor=embeddings_tensor,
)
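# Small illustrative example (assumed shapes, CPU tensors): a LoRA with
# rank 2 and lora_alpha 4 has scaling = 4 / 2 = 2.0; optimize() folds that
# factor into lora_b so downstream code can always apply scale 1.0.
def _example_lora_scaling() -> "LoRALayerWeights":
    lora = LoRALayerWeights(
        module_name="q_proj",       # hypothetical module name
        rank=2,
        lora_alpha=4,
        lora_a=torch.zeros(16, 2),  # (input_dim, rank)
        lora_b=torch.ones(2, 32),   # (rank, output_dim)
    )
    assert lora.scaling == 2.0
    lora.optimize()
    assert lora.scaling == 1 and bool((lora.lora_b == 2.0).all())
    return lora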
class PackedLoRALayerWeights(LoRALayerWeights):
"""LoRA used for packed layers (eg. qkv_proj)."""
def __init__(
self,
module_name: str,
rank: int,
lora_alphas: List[int],
lora_a: List[torch.Tensor],
lora_b: List[torch.Tensor],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
module_name=module_name,
rank=rank,
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling,
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [
lora_alpha / self.rank for lora_alpha in self.lora_alphas
]
@classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
for lora in loras:
if lora is None:
continue
lora.optimize()
rank = first_lora.rank
module_name = first_lora.module_name
obj = cls(
module_name,
rank,
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras])
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None:
continue
self.lora_b[i] *= self.scaling[i]
self.scaling[i] = 1
return self
@property
def input_dim(self) -> int:
raise NotImplementedError()
@property
def output_dim(self) -> int:
raise NotImplementedError()
@property
def is_packed(self) -> bool:
return True
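# Illustrative packing example (assumed shapes): two per-submodule LoRAs are
# optimized and packed into a single PackedLoRALayerWeights whose lora_a and
# lora_b become lists aligned with the packed sublayers.
def _example_pack_loras() -> "PackedLoRALayerWeights":
    gate = LoRALayerWeights("gate_proj", 2, 4, torch.zeros(16, 2),
                            torch.zeros(2, 32))
    up = LoRALayerWeights("up_proj", 2, 4, torch.zeros(16, 2),
                          torch.zeros(2, 32))
    packed = PackedLoRALayerWeights.pack([gate, up])
    assert packed.is_packed and packed.rank == 2
    assert packed.scaling == [1, 1]  # scaling already folded into lora_b
    return packed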
import copy
import json
import logging
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type,
Union)
import safetensors.torch
import torch
from torch import nn
from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
logger = logging.getLogger(__name__)
# TODO: The mappings below should be moved to individual model classes.
PACKED_MODULES_CFG = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
TARGET_MODULES_QKV = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
EMBEDDING_PADDING_MODULES = ["lm_head"]
_GLOBAL_LORA_ID = 0
def convert_mapping(
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
                Same as sampler_indices, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
indices_len: List of lengths of the above tensors.
"""
indices = list(mapping.index_mapping).copy()
embedding_indices = indices.copy()
lora_indices = indices.copy()
prompt_mapping = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i])
if indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
indices[i] = i
lora_indices[i] = lora_idx
indices = torch.tensor([indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1])
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
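# Worked example of the conversion above (illustrative values only; the real
# function also materializes these indices as CUDA tensors). With two active
# LoRAs occupying slots 0 and 1, lora_index_to_id == [7, 9], and a mapping of
# index_mapping=(7, 7, 9, 0) / prompt_mapping=(7, 9, 0), where 0 means
# "no LoRA":
#   base_indices    -> [0, 0, 1, -1]   one slot index per token
#   sampler_indices -> [0, 1, -1]      one slot index per sampled row
#   sampler_indices_padded: -1 is first replaced by max_loras - 1, then each
#   value v at row r becomes r + v * len(sampler_indices_padded), giving
#   [0, 4, 8] for max_loras == 3.
#   indices_len     -> (4, 3, 3, 4)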
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
return _GLOBAL_LORA_ID
class LoRAModel:
"""A LoRA fine-tuned model."""
def __init__(
self,
lora_model_id: int,
rank: int,
loras: Dict[str, LoRALayerWeights],
) -> None:
self.id = lora_model_id
assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
@property
def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size
for lora in self.loras.values()) if self.loras else 0
def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
# (yard1): TODO see if we can derive target_embedding_padding automatically
@classmethod
def from_lora_tensors(
cls,
lora_model_id: int,
rank: int,
lora_alpha: int,
tensors: Dict[str, torch.Tensor],
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and not in_wsl()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name),
None)
if embeddings_module:
lora_embeddings_tensor = embeddings[
EMBEDDING_MODULES[embeddings_module]].to(
device=device, dtype=dtype)
if pin_memory:
lora_embeddings_tensor = (
lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None,
lora_embeddings_tensor)
if is_lora_a:
loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t()
if pin_memory:
loras[module_name].lora_a = loras[
module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
if any(name in module_name
for name in EMBEDDING_PADDING_MODULES
) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1]
addition = target_embedding_padding - lora_b.shape[1]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, addition))
if pin_memory:
loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, rank, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint."""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path)
elif os.path.isfile(lora_bin_file_path):
tensors = torch.load(lora_bin_file_path)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
embeddings = None
if os.path.isfile(new_embeddings_tensor_path):
embeddings = safetensors.torch.load_file(
new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path)
with open(lora_config_path) as f:
config = json.load(f)
rank = config["r"]
lora_alpha = config["lora_alpha"]
return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
rank=rank,
lora_alpha=lora_alpha,
tensors=tensors,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
)
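# Illustrative usage sketch (the directory path is hypothetical): a PEFT-style
# adapter directory containing adapter_config.json plus
# adapter_model.safetensors (or adapter_model.bin) can be loaded onto the CPU
# for inspection like this:
#
#   lora = LoRAModel.from_local_checkpoint("/path/to/my-adapter",
#                                          device="cpu")
#   print(lora.id, lora.rank, sorted(lora.loras))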
class LoRAModelManager:
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
):
"""Create a LoRAModelManager and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
lora_target_modules: the target modules patterns to be adapted.
Support both single module name and a list of module names.
            packed_modules_mapping: the mapping for packed modules. vLLM
                packs some modules into one module, e.g., q_proj, k_proj,
                and v_proj are packed into qkv_proj. The LoRA weights for
                these submodules are stored separately in the checkpoint
                and merged into a packed LoRA by the manager.
"""
self.lora_config = lora_config
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.embeddings_indices = torch.empty(2,
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.offsets = []
        # 4 is the number of indices tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len = [None] * 4
self.model: nn.Module = model
        self.lora_target_modules: List[str] = copy.deepcopy(
            [lora_target_modules] if isinstance(lora_target_modules, str)
            else lora_target_modules)
self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping = None
self._create_lora_modules()
self.model.lora_manager = self
@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras
@property
def lora_slots(self) -> int:
return self.lora_config.max_loras
def __len__(self) -> int:
return len(self._registered_loras)
def activate_lora(
self,
lora_id: int,
) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_loras:
return False
        first_free_slot = next(
            ((i, slot_id) for i, slot_id in enumerate(self.lora_index_to_id)
             if slot_id is None), None)
if first_free_slot is None:
raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id]
logger.debug(
f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
if module_lora:
module_lora.optimize()
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
module_lora.embeddings_tensor)
else:
module.reset_lora(index)
return True
def _deactivate_lora(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
def deactivate_lora(self, lora_id: int) -> bool:
"""Remove a LoRA from a GPU buffer."""
if lora_id in self._active_loras:
self._deactivate_lora(lora_id)
self._active_loras.pop(lora_id)
return True
return False
def _add_lora(self, lora: LoRAModel) -> bool:
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
self._add_lora(lora)
return True
return False
def remove_lora(self, lora_id: int) -> bool:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
# Maintain the reference
self.indices_len[:] = indices_len
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
if self._last_mapping != lora_mapping:
self._set_lora_mapping(lora_mapping)
self._last_mapping = lora_mapping
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras)
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self) -> bool:
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_loras.clear()
def _create_lora_modules(self):
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name):
continue
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
self.model.config))
# (yard1): TODO make this more robust
if "lm_head" in module_name:
sampler_module = self.model.get_submodule("sampler")
new_module = replace_submodule(
self.model, "sampler",
from_layer_sampler(sampler_module, module, self.lora_slots,
self.lora_config, self.model.config))
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded,
self.embeddings_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
self.modules[module_name] = module
def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
if parts[-1] in EMBEDDING_MODULES:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1])
output_dim = module.base_layer.embedding_dim if hasattr(
module.base_layer,
"embedding_dim") else module.base_layer.weight.shape[0]
embeddings_tensor_dim = (module.base_layer.embedding_dim if
hasattr(module.base_layer,
"embedding_dim") else
module.base_layer.weight.shape[1])
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
output_dim,
rank,
module.lora_a_stacked.dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.lora_a_stacked.shape[-1],
module.lora_b_stacked.shape[-2],
rank,
module.lora_a_stacked.dtype,
"cpu",
)
lora.optimize()
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
module.lora_a_stacked[i].shape[-1],
module.lora_b_stacked[i].shape[-2],
rank,
module.lora_a_stacked[i].dtype,
"cpu",
)
lora.optimize()
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module),
module_name) or target_module == module_name
for target_module in self.lora_target_modules)
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name)
if not replacements:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [
prefix + "." + r if prefix else r for r in replacements
]
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
replacement_loras.append(lora)
if lora:
has_replacement = True
if not has_replacement:
continue
for i in range(len(replacement_loras)):
if replacement_loras[i]:
continue
replacement_loras[i] = None
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras)
class LoRALRUCache(LRUCache):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: Any):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config, lora_target_modules,
packed_modules_mapping)
self._registered_loras: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_lora)
self._active_loras: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_lora)
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras.cache)
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
if lora.id not in self._registered_loras:
self._add_lora(lora)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_loras.touch(lora.id)
was_added = False
return was_added
def activate_lora(
self,
lora_id: int,
) -> bool:
if lora_id not in self._active_loras and len(
self._active_loras) >= self.lora_slots:
self._active_loras.remove_oldest()
result = super().activate_lora(lora_id)
# We always touch to update the LRU cache order
self._active_loras.touch(lora_id)
return result
def remove_oldest_lora(self) -> bool:
if len(self._registered_loras) > 0:
self._registered_loras.remove_oldest()
return True
return False
def create_lora_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not getattr(model, "supports_lora", False):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
lora_target_modules=target_modules,
**kwargs)
return lora_manager
# Based on code from https://github.com/punica-ai/punica
from typing import Optional
import torch
import_exc = None
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
import_exc = e
if import_exc is None:
def bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
matrices.
indicies: Shape: `[B]`. Indices of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
*,
buffer: Optional[torch.Tensor] = None):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
1.0)
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
scale)
def add_lora_slice(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
y_offset: int,
y_slice_size: int,
*,
buffer: Optional[torch.Tensor] = None):
"""
Same as `add_lora` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv_low_level(
buffer,
x,
wa_t_all,
indicies,
layer_idx,
1.0,
x.size(1),
buffer.size(1),
0,
)
punica_kernels.dispatch_bgmv_low_level(
y,
buffer,
wb_t_all,
indicies,
layer_idx,
scale,
buffer.size(1),
y_slice_size,
y_offset,
)
else:
def _raise_exc(
*args, # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
if torch.cuda.get_device_capability() < (8, 0):
raise ImportError("punica LoRA kernels require compute "
"capability>=8.0") from import_exc
else:
raise ImportError(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.") from import_exc
bgmv = _raise_exc
add_lora = _raise_exc
add_lora_slice = _raise_exc
__all__ = [
"bgmv",
"add_lora",
"add_lora_slice",
]
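# Plain-PyTorch restatement of the `add_lora` semantics documented above, for
# illustration only: the real kernels run fused on CUDA. Tensor names mirror
# the docstrings; everything else here is an assumption.
def _reference_add_lora(y: torch.Tensor, x: torch.Tensor,
                        wa_t_all: torch.Tensor, wb_t_all: torch.Tensor,
                        indicies: torch.LongTensor, layer_idx: int,
                        scale: float) -> None:
    # y: [B, H2], x: [B, H1], wa_t_all: [N, L, R, H1], wb_t_all: [N, L, H2, R]
    for i in range(x.size(0)):
        wa = wa_t_all[indicies[i], layer_idx].transpose(-1, -2)  # [H1, R]
        wb = wb_t_all[indicies[i], layer_idx].transpose(-1, -2)  # [R, H2]
        y[i] += (x[i].unsqueeze(0) @ wa @ wb).squeeze(0) * scale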
from dataclasses import dataclass
@dataclass
class LoRARequest:
"""
Request for a LoRA adapter.
    Note that this class should only be used internally. For online
serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
lora_name: str
lora_int_id: int
lora_local_path: str
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(
f"lora_int_id must be > 0, got {self.lora_int_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, LoRARequest) and self.lora_int_id == value.lora_int_id
def __hash__(self) -> int:
return self.lora_int_id
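# Illustrative usage (not from the source): LoRARequest hashes and compares by
# lora_int_id only, so two requests pointing at the same adapter id collapse to
# a single entry in sets and dicts. The names and local path are made up.
if __name__ == "__main__":
    req_a = LoRARequest(lora_name="sql-lora", lora_int_id=1,
                        lora_local_path="/path/to/sql-lora")
    req_b = LoRARequest(lora_name="sql-lora-copy", lora_int_id=1,
                        lora_local_path="/path/to/sql-lora")
    assert req_a == req_b            # equality keyed on lora_int_id
    assert len({req_a, req_b}) == 1  # deduplicates in sets and dict keys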
import logging
from typing import Tuple
from torch import nn
logger = logging.getLogger(__name__)
def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
"""Parse the name of lora weights.
args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
return:
Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
"""
parts = name.split(".")
assert parts[0] == "base_model"
assert parts[1] == "model"
if parts[-1] == "weight":
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported format")
import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional, Set, Type, Union
import torch
from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.layers import LoRAMapping
from vllm.config import LoRAConfig
logger = logging.getLogger(__name__)
class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
vocab_size: int, lora_config: LoRAConfig,
device: torch.device):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
@abstractproperty
def is_enabled(self) -> bool:
...
@abstractmethod
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
...
@abstractmethod
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
...
@abstractmethod
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
...
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
    def remove_all_loras(self) -> None:
...
@abstractmethod
def list_loras(self) -> Set[int]:
...
class WorkerLoRAManager(AbstractWorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.
    On every request, the requested LoRAs are loaded (unless they are
    already loaded) and every other LoRA is unloaded."""
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
lora_model_cls: Type[LoRAModel] = LoRAModel,
):
self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
@property
def is_enabled(self) -> bool:
return True
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
lora_manager = create_lora_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
target_modules=target_modules,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
)
self._lora_manager: LoRAModelManager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
new_loras = set(loras_map)
loras_to_add = new_loras - loras_that_exist
loras_to_remove = loras_that_exist - new_loras
for lora_id in loras_to_remove:
self.remove_lora(lora_id)
for lora_id in loras_to_add:
self.add_lora(loras_map[lora_id])
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try:
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size +
self.lora_config.lora_extra_vocab_size,
)
except Exception as e:
raise RuntimeError(
f"Loading lora {lora_request.lora_local_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank "
f"{self.lora_config.max_lora_rank}.")
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora.extra_vocab_size} is greater than "
f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}."
)
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
return self._lora_manager.add_lora(
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
rank))
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
self._lora_manager.activate_lora(lora.id)
return loaded
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
    def remove_all_loras(self) -> None:
        self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]:
return set(self._lora_manager.list_loras())
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
Uses an LRU Cache. Every request, the requested LoRAs will be loaded
(unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity."""
_lora_manager_cls: Type[
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
lora_manager = create_lora_manager(
model,
target_modules=target_modules,
lora_manager_cls=self._lora_manager_cls,
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._lora_manager: LRUCacheLoRAModelManager = lora_manager
return lora_manager.model
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
for lora in loras_map.values():
self.add_lora(lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded
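# Illustrative, standalone sketch (not part of vLLM) of the bookkeeping that
# WorkerLoRAManager._apply_loras performs: the requested adapter ids are diffed
# against the ids already loaded on the worker, stale adapters are removed and
# missing ones are added. The ids and slot count below are arbitrary.
if __name__ == "__main__":
    loaded_ids = {1, 2, 3}     # adapters currently resident on the worker
    requested_ids = {2, 3, 5}  # adapters needed for the next batch
    lora_slots = 4             # configured GPU LoRA slots
    assert len(requested_ids) <= lora_slots
    to_remove = loaded_ids - requested_ids  # {1}
    to_add = requested_ids - loaded_ids     # {5}
    print(f"remove={sorted(to_remove)}, add={sorted(to_add)}")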
...@@ -12,23 +12,32 @@ class InputMetadata: ...@@ -12,23 +12,32 @@ class InputMetadata:
max_context_len: The maximum context length. max_context_len: The maximum context length.
context_lens: the length of attention context for each sequence. context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block) block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type to store kv cache.
""" """
def __init__( def __init__(
self, self,
is_prompt: bool, is_prompt: bool,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
prompt_lens: Optional[torch.Tensor],
max_seq_len: Optional[int],
start_loc: Optional[torch.Tensor],
max_context_len: Optional[int], max_context_len: Optional[int],
context_lens: Optional[torch.Tensor], context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor], block_tables: Optional[torch.Tensor],
use_cuda_graph: bool, use_cuda_graph: bool,
kv_cache_dtype: str,
) -> None: ) -> None:
self.is_prompt = is_prompt self.is_prompt = is_prompt
self.prompt_lens = prompt_lens
self.max_seq_len = max_seq_len
self.start_loc = start_loc
self.max_context_len = max_context_len self.max_context_len = max_context_len
self.slot_mapping = slot_mapping self.slot_mapping = slot_mapping
self.context_lens = context_lens self.context_lens = context_lens
self.block_tables = block_tables self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph self.use_cuda_graph = use_cuda_graph
self.kv_cache_dtype = kv_cache_dtype
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
...@@ -41,4 +50,5 @@ class InputMetadata: ...@@ -41,4 +50,5 @@ class InputMetadata:
f"slot_mapping={self.slot_mapping}, " f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, " f"context_lens={self.context_lens}, "
f"block_tables={self.block_tables}, " f"block_tables={self.block_tables}, "
f"use_cuda_graph={self.use_cuda_graph})") f"use_cuda_graph={self.use_cuda_graph}, "
f"kv_cache_dtype={self.kv_cache_dtype})")
...@@ -10,6 +10,8 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, ...@@ -10,6 +10,8 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops from vllm._C import ops
from vllm._C import cache_ops from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
context_attention_fwd)
from vllm.utils import is_hip from vllm.utils import is_hip
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
...@@ -96,6 +98,7 @@ class PagedAttention(nn.Module): ...@@ -96,6 +98,7 @@ class PagedAttention(nn.Module):
key_cache, key_cache,
value_cache, value_cache,
input_metadata.slot_mapping.flatten(), input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
) )
if input_metadata.is_prompt: if input_metadata.is_prompt:
...@@ -115,61 +118,76 @@ class PagedAttention(nn.Module): ...@@ -115,61 +118,76 @@ class PagedAttention(nn.Module):
self.num_kv_heads, self.num_kv_heads,
self.num_queries_per_kv, self.num_queries_per_kv,
value.shape[-1]) value.shape[-1])
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
# Set attention bias if not provided. This typically happens at the # TODO(woosuk): Too many view operations. Let's try to reduce
# very attention layer of every iteration. # them in the future for code readability.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None: if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens( query = query.unsqueeze(0)
[seq_len] * batch_size) key = key.unsqueeze(0)
if self.sliding_window is not None: value = value.unsqueeze(0)
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else: else:
input_metadata.attn_bias = _make_alibi_bias( query = query.unflatten(0, (batch_size, seq_len))
self.alibi_slopes, self.num_kv_heads, batch_size, key = key.unflatten(0, (batch_size, seq_len))
seq_len, query.dtype) value = value.unflatten(0, (batch_size, seq_len))
# TODO(woosuk): Too many view operations. Let's try to reduce them out = xops.memory_efficient_attention_forward(
# in the future for code readability. query,
if self.alibi_slopes is None: key,
query = query.unsqueeze(0) value,
key = key.unsqueeze(0) attn_bias=input_metadata.attn_bias,
value = value.unsqueeze(0) p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else: else:
query = query.unflatten(0, (batch_size, seq_len)) # prefix-enabled attention
key = key.unflatten(0, (batch_size, seq_len)) output = torch.empty_like(query)
value = value.unflatten(0, (batch_size, seq_len)) context_attention_fwd(
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else:
# Decoding run.
if key_cache is not None and value_cache is not None:
output = _paged_attention(
query, query,
key,
value,
output,
key_cache, key_cache,
value_cache, value_cache,
input_metadata, input_metadata.block_tables, # [BS, max_block_per_request]
self.num_kv_heads, input_metadata.start_loc,
self.scale, input_metadata.prompt_lens,
self.alibi_slopes, input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
) )
else:
# This happens during the initial memory profiling run for else:
# CUDA graphs. # Decoding run.
output = torch.zeros_like(query) output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size) return output.view(batch_size, seq_len, hidden_size)
...@@ -248,6 +266,7 @@ def _paged_attention( ...@@ -248,6 +266,7 @@ def _paged_attention(
block_size, block_size,
input_metadata.max_context_len, input_metadata.max_context_len,
alibi_slopes, alibi_slopes,
input_metadata.kv_cache_dtype,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -278,5 +297,6 @@ def _paged_attention( ...@@ -278,5 +297,6 @@ def _paged_attention(
block_size, block_size,
input_metadata.max_context_len, input_metadata.max_context_len,
alibi_slopes, alibi_slopes,
input_metadata.kv_cache_dtype,
) )
return output return output
"""Fused MoE kernel."""
from typing import Tuple
import torch
import triton
import triton.language as tl
from vllm._C import ops
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,
and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids`
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
return sorted_ids, expert_ids, num_tokens_post_pad
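# Illustrative CPU-only reference (not the CUDA op) that reproduces the
# moe_align_block_size docstring example above. The real kernel may order
# tokens within an expert group differently; only the group-and-pad contract
# matters to the fused kernel.
def _reference_moe_align(topk_ids: torch.Tensor, block_size: int):
    flat = topk_ids.flatten().tolist()
    pad_id = len(flat)  # sentinel index, one past the last real token
    sorted_ids, expert_ids = [], []
    for expert in sorted(set(flat)):
        tokens = [i for i, e in enumerate(flat) if e == expert]
        while len(tokens) % block_size:  # pad each group to a block multiple
            tokens.append(pad_id)
        sorted_ids.extend(tokens)
        expert_ids.extend([expert] * (len(tokens) // block_size))
    return sorted_ids, expert_ids, len(sorted_ids)
if __name__ == "__main__":
    ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
    print(_reference_moe_align(ids, block_size=4))
    # ([3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12], [1, 2, 3, 4], 16)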
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, config: dict):
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
fused_moe_kernel[grid](
A,
B,
C,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
**config,
)
def fused_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
              inplace: bool = False):
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The weights for the top-k selected experts.
- topk_ids (torch.Tensor): The indices of the top-k selected experts.
- inplace (bool): If True, perform the operation in-place. Defaults to False.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
E, N, _ = w1.shape
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
if topk_ids.numel() <= w1.shape[0]:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, False,
topk_ids.shape[1], config)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, True, 1,
config)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
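# Illustrative, loop-based reference (not part of vLLM) of the math fused_moe
# implements, assuming silu_and_mul computes SiLU(x[..., :N//2]) * x[..., N//2:].
# It can be compared against fused_moe on small random CUDA inputs; all shapes
# are whatever the caller provides.
def _reference_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor,
                         w2: torch.Tensor, topk_weights: torch.Tensor,
                         topk_ids: torch.Tensor) -> torch.Tensor:
    M, _ = hidden_states.shape
    _, N, _ = w1.shape
    out = torch.zeros(M, w2.shape[1],
                      dtype=hidden_states.dtype, device=hidden_states.device)
    for i in range(M):
        for j in range(topk_ids.shape[1]):
            e = topk_ids[i, j]
            gate_up = hidden_states[i] @ w1[e].t()  # [N]
            act = torch.nn.functional.silu(
                gate_up[:N // 2]) * gate_up[N // 2:]  # [N // 2]
            out[i] += topk_weights[i, j] * (act @ w2[e].t())
    return out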
...@@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
shard_id = tp_rank // self.num_kv_head_replicas if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
......
...@@ -153,7 +153,16 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -153,7 +153,16 @@ class AWQLinearMethod(LinearMethodBase):
pack_factor = self.quant_config.pack_factor pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1]) reshaped_x = x.reshape(-1, x.shape[-1])
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.reshape(out_shape) return out.reshape(out_shape)