Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
...@@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, ...@@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
truncate_tool_call_ids)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"tool_choice = \"required\" is not supported!") "tool_choice = \"required\" is not supported!")
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request)
if (request.tool_choice == "auto" and if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None) not (self.enable_auto_tools and tool_parser is not None)
...@@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
elif request.tool_choice and type( elif request.tool_choice and type(
request.tool_choice) is ChatCompletionNamedToolChoiceParam: request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_call_class = MistralToolCall if isinstance(
tokenizer, MistralTokenizer) else ToolCall
message = ChatMessage( message = ChatMessage(
role=role, role=role,
content="", content="",
tool_calls=[ tool_calls=[
ToolCall(function=FunctionCall( tool_call_class(function=FunctionCall(
name=request.tool_choice.function.name, name=request.tool_choice.function.name,
arguments=output.text)) arguments=output.text))
]) ])
......
...@@ -31,7 +31,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -31,7 +31,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ErrorResponse, RerankRequest, ErrorResponse, RerankRequest,
ScoreRequest, ScoreRequest,
TokenizeChatRequest, TokenizeChatRequest,
TokenizeCompletionRequest) TokenizeCompletionRequest,
TranscriptionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable # yapf: enable
...@@ -57,7 +58,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, ...@@ -57,7 +58,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest] TokenizeChatRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest] AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
TranscriptionRequest]
class TextTokensPrompt(TypedDict): class TextTokensPrompt(TypedDict):
...@@ -400,8 +402,7 @@ class OpenAIServing: ...@@ -400,8 +402,7 @@ class OpenAIServing:
_chat_template_kwargs.update(chat_template_kwargs or {}) _chat_template_kwargs.update(chat_template_kwargs or {})
request_prompt: Union[str, List[int]] request_prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer:
request_prompt = apply_mistral_chat_template( request_prompt = apply_mistral_chat_template(
tokenizer, tokenizer,
messages=messages, messages=messages,
...@@ -450,6 +451,8 @@ class OpenAIServing: ...@@ -450,6 +451,8 @@ class OpenAIServing:
prompt_token_ids=prompt_inputs["prompt_token_ids"]) prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None: if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data engine_prompt["multi_modal_data"] = mm_data
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
return conversation, [request_prompt], [engine_prompt] return conversation, [request_prompt], [engine_prompt]
......
...@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing): ...@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
tokenize_async = make_async(tokenizer.__call__, tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor) executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q, prompt_inputs = await tokenize_async(q,
text_pair=t, text_pair=t,
**tokenization_kwargs) **tokenization_kwargs)
......
# SPDX-License-Identifier: Apache-2.0
import asyncio
import io
from typing import AsyncGenerator, Optional, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse,
RequestResponseMetadata,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseVerbose)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
logger = init_logger(__name__)
# Language code -> English language name.
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
# A `request.language` found in this table is accepted by
# `_preprocess_transcription` without any warning.
# TODO these configs should live somewhere with the model so we can support
# additional ones
ISO639_1_SUPPORTED_LANGS = {
"af": "Afrikaans",
"ar": "Arabic",
"hy": "Armenian",
"az": "Azerbaijani",
"be": "Belarusian",
"bs": "Bosnian",
"bg": "Bulgarian",
"ca": "Catalan",
"zh": "Chinese",
"hr": "Croatian",
"cs": "Czech",
"da": "Danish",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"gl": "Galician",
"de": "German",
"el": "Greek",
"he": "Hebrew",
"hi": "Hindi",
"hu": "Hungarian",
"is": "Icelandic",
"id": "Indonesian",
"it": "Italian",
"ja": "Japanese",
"kn": "Kannada",
"kk": "Kazakh",
"ko": "Korean",
"lv": "Latvian",
"lt": "Lithuanian",
"mk": "Macedonian",
"ms": "Malay",
"mr": "Marathi",
"mi": "Maori",
"ne": "Nepali",
"no": "Norwegian",
"fa": "Persian",
"pl": "Polish",
"pt": "Portuguese",
"ro": "Romanian",
"ru": "Russian",
"sr": "Serbian",
"sk": "Slovak",
"sl": "Slovenian",
"es": "Spanish",
"sw": "Swahili",
"sv": "Swedish",
"tl": "Tagalog",
"ta": "Tamil",
"th": "Thai",
"tr": "Turkish",
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh"
}
# Language code -> English language name for languages that are accepted by
# `_preprocess_transcription` only with a logged warning about limited
# accuracy (reported WER >= 0.5).
# NOTE(review): "mi" is also present in ISO639_1_SUPPORTED_LANGS; because the
# supported table is checked first, the "mi" entry here never triggers the
# warning. Also, "haw" is a three-letter code and "jw" is not the standard
# ISO 639-1 code for Javanese ("jv") - presumably these mirror the model's
# own language tokens; confirm before normalising.
ISO639_1_OTHER_LANGS = {
"lo": "Lao",
"jw": "Javanese",
"tk": "Turkmen",
"yi": "Yiddish",
"so": "Somali",
"bn": "Bengali",
"nn": "Norwegian Nynorsk",
"si": "Sinhala",
"yo": "Yoruba",
"sa": "Sanskrit",
"mi": "Māori",
"fo": "Faroese", # codespell:ignore
"mt": "Maltese",
"tg": "Tajik",
"mg": "Malagasy",
"haw": "Hawaiian",
"km": "Khmer",
"br": "Breton",
"ps": "Pashto",
"ln": "Lingala",
"la": "Latin",
"ml": "Malayalam",
"sq": "Albanian",
"su": "Sundanese",
"eu": "Basque",
"ka": "Georgian",
"uz": "Uzbek",
"sn": "Shona",
"ht": "Haitian",
"as": "Assamese",
"mn": "Mongolian",
"te": "Telugu",
"pa": "Panjabi",
"tt": "Tatar",
"gu": "Gujarati",
"oc": "Occitan",
"ha": "Hausa",
"ba": "Bashkir",
"my": "Burmese",
"sd": "Sindhi",
"am": "Amharic",
"lb": "Luxembourgish",
"bo": "Tibetan"
}
# Maximum accepted upload size. The check divides the byte length by 1024**2,
# so the unit is effectively MiB.
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
# Maximum accepted clip duration in seconds, compared against the value
# returned by librosa.get_duration().
# TODO get from processor.feature_extractor.chunk_length
MAX_AUDIO_CLIP_DURATION_S = 30
class OpenAIServingTranscription(OpenAIServing):
    """Serving layer for the OpenAI-compatible audio transcription API.

    Validates a ``TranscriptionRequest``, decodes the uploaded audio with
    librosa, builds an explicit encoder/decoder prompt and returns the text
    of the last output produced by the engine. Only non-streaming responses
    with ``response_format`` ``text`` or ``json`` are produced.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Initialise the base serving class and log sampling overrides.

        All arguments are forwarded unchanged to ``OpenAIServing.__init__``.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                diff_sampling_param)

    async def _preprocess_transcription(
        self,
        request: TranscriptionRequest,
        audio_data: bytes,
    ) -> PromptType:
        """Validate the request and build the engine prompt.

        Args:
            request: The transcription request; ``language`` and ``prompt``
                are read from it.
            audio_data: Raw bytes of the uploaded audio file.

        Returns:
            A prompt dict with the decoded audio as encoder multi-modal data
            and a decoder prompt made of the control tokens followed by
            ``request.prompt``.

        Raises:
            ValueError: If the language code is in neither language table,
                the payload exceeds ``MAX_AUDIO_CLIP_FILESIZE_MB``, or the
                decoded clip is longer than ``MAX_AUDIO_CLIP_DURATION_S``.
        """
        # Validate request
        # TODO language should be optional and can be guessed.
        # For now we default to en. See
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
        language = request.language
        if language:
            if language in ISO639_1_SUPPORTED_LANGS:
                pass
            elif language in ISO639_1_OTHER_LANGS:
                logger.warning(
                    "The selected language %s has limited accuracy with"
                    " reported WER>=0.5. Results may be less accurate "
                    "for this choice.", language)
            else:
                # Validation is done against the language *codes* (the dict
                # keys), so those are what the caller needs to see.
                raise ValueError(
                    f"Unsupported language: {language}. "
                    "Language should be one of: "
                    f"{list(ISO639_1_SUPPORTED_LANGS)} "
                    f"or {list(ISO639_1_OTHER_LANGS)}")
        lang_token = f"<|{language}|>" if language else "<|en|>"

        if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
            raise ValueError("Maximum file size exceeded.")

        # NOTE(review): librosa.load is a synchronous call made from a
        # coroutine; it runs on the calling thread while decoding.
        with io.BytesIO(audio_data) as bytes_:
            y, sr = librosa.load(bytes_)
        if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
            raise ValueError(
                f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
                "exceeded.")

        # NOTE(review): request.prompt is interpolated as-is; this assumes
        # the request model defaults it to a string - confirm in protocol.
        prompt = {
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": (y, sr),
                },
            },
            "decoder_prompt":
            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
        }
        return cast(PromptType, prompt)

    # TODO (varun) : Make verbose response work !
    async def create_transcription(
        self, audio_data: bytes, request: TranscriptionRequest,
        raw_request: Request
    ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
               ErrorResponse]:
        """Transcription API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/audio/createTranscription
        for the API specification. This API mimics the OpenAI transcription API.

        Args:
            audio_data: Raw bytes of the uploaded audio file.
            request: The parsed transcription request.
            raw_request: The incoming HTTP request; when truthy, the request
                metadata is attached to ``raw_request.state``.

        Returns:
            A ``TranscriptionResponse`` carrying the text of the final engine
            output, or the result of ``create_error_response`` when
            validation, preprocessing or generation fails.

        Raises:
            Exception: ``engine_client.dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        if request.response_format not in ['text', 'json']:
            return self.create_error_response(
                "Currently only support response_format `text` or `json`")

        # TODO cmpl->transcription?
        request_id = f"cmpl-{self._base_request_id(raw_request)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            if lora_request:
                return self.create_error_response(
                    "Currently do not support LoRA for Transcription.")
            if prompt_adapter_request:
                return self.create_error_response(
                    "Currently do not support PromptAdapter for Transcription."
                )

            prompt = await self._preprocess_transcription(
                request=request,
                audio_data=audio_data,
            )

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
        try:
            # TODO(rob): subtract len of tokenized prompt.
            default_max_tokens = self.model_config.max_model_len
            default_params = self.model_config.get_diff_sampling_param()
            sampling_params = request.to_sampling_params(
                default_max_tokens, default_params)

            self._log_inputs(
                request_id,
                prompt['decoder_prompt'],  # type: ignore
                params=sampling_params,
                lora_request=None,
                prompt_adapter_request=None)

            result_generator = self.engine_client.generate(
                prompt,
                sampling_params,
                request_id,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # TODO(rob): figure out a way to pipe streaming in.
        # Non-streaming response.
        # Keep only the last output; start from None so that an empty
        # generator is reported as an error instead of an unbound local.
        result: Optional[RequestOutput] = None
        try:
            assert result_generator is not None
            async for op in result_generator:
                result = op
            if result is None:
                return self.create_error_response(
                    "The engine produced no output for this request.")
            return TranscriptionResponse(text=result.outputs[0].text)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
...@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall): ...@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
@staticmethod @staticmethod
def generate_random_id(): def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. # Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9)) return "".join(choices(ALPHANUMERIC, k=9))
......
...@@ -60,6 +60,7 @@ if TYPE_CHECKING: ...@@ -60,6 +60,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MM_INPUT_CACHE_SIZE: int = 256
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
...@@ -93,6 +94,8 @@ if TYPE_CHECKING: ...@@ -93,6 +94,8 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
def get_default_cache_root(): def get_default_cache_root():
...@@ -431,15 +434,21 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -431,15 +434,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Timeout for fetching videos when serving multimodal models # Timeout for fetching videos when serving multimodal models
# Default is 15 seconds # Default is 30 seconds
"VLLM_VIDEO_FETCH_TIMEOUT": "VLLM_VIDEO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "15")), lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30")),
# Timeout for fetching audio when serving multimodal models # Timeout for fetching audio when serving multimodal models
# Default is 10 seconds # Default is 10 seconds
"VLLM_AUDIO_FETCH_TIMEOUT": "VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Cache size for multimodal feature/input cache for multimodal models
# in unit of number of multimodal data items (e.g. image, video, audio).
# Default is 256 multimodal data items.
"VLLM_MM_INPUT_CACHE_SIZE":
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256")),
# Path to the XLA persistent cache directory. # Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs. # Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH": "VLLM_XLA_CACHE_PATH":
...@@ -608,6 +617,18 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -608,6 +617,18 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# models the alignment is already naturally aligned to 256 bytes. # models the alignment is already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE": "VLLM_CUDA_MEM_ALIGN_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))), lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
# On some systems, find_loaded_library() may not work, so we allow users to
# specify the path through the environment variable VLLM_CUDART_SO_PATH.
"VLLM_CUDART_SO_PATH":
lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
# Contiguous cache fetching to avoid using costly gather operation on
# Gaudi3. This is only applicable to HPU contiguous cache. If set to true,
# contiguous cache fetch will be used.
"VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
("1", "true"),
} }
# end-env-vars-definition # end-env-vars-definition
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union) Union)
...@@ -8,11 +9,11 @@ from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, ...@@ -8,11 +9,11 @@ from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
import torch.nn as nn import torch.nn as nn
from typing_extensions import TypeVar from typing_extensions import TypeVar
import vllm.platforms
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async from vllm.utils import make_async
...@@ -108,8 +109,8 @@ class ExecutorBase(ABC): ...@@ -108,8 +109,8 @@ class ExecutorBase(ABC):
""" """
# NOTE: This is logged in the executor because there can be >1 workers. # NOTE: This is logged in the executor because there can be >1 workers.
logger.info("# %s blocks: %d, # CPU blocks: %d", logger.info("# %s blocks: %d, # CPU blocks: %d",
current_platform.dispatch_key, num_gpu_blocks, vllm.platforms.current_platform.device_name,
num_cpu_blocks) num_gpu_blocks, num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size / max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len) self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx", logger.info("Maximum concurrency for %s tokens per request: %.2fx",
...@@ -200,15 +201,23 @@ class ExecutorBase(ABC): ...@@ -200,15 +201,23 @@ class ExecutorBase(ABC):
if self.is_sleeping: if self.is_sleeping:
logger.warning("Executor is already sleeping.") logger.warning("Executor is already sleeping.")
return return
time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level)) self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter()
self.is_sleeping = True self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep)
def wake_up(self): def wake_up(self):
if not self.is_sleeping: if not self.is_sleeping:
logger.warning("Executor is not sleeping.") logger.warning("Executor is not sleeping.")
return return
time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up") self.collective_rpc("wake_up")
time_after_wakeup = time.perf_counter()
self.is_sleeping = False self.is_sleeping = False
logger.info("It took %.6f seconds to wake up.",
time_after_wakeup - time_before_wakeup)
def save_sharded_state( def save_sharded_state(
self, self,
......
...@@ -101,6 +101,10 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -101,6 +101,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
self.driver_worker.execute_method) self.driver_worker.execute_method)
def shutdown(self) -> None: def shutdown(self) -> None:
logger.info(
"Shutting down Ray distributed executor. If you see error log "
"from logging.cc regarding SIGTERM received, please ignore because "
"this is the expected termination process in Ray.")
if hasattr(self, "forward_dag") and self.forward_dag is not None: if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown() self.forward_dag.teardown()
import ray import ray
......
...@@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union ...@@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import msgspec import msgspec
import vllm.platforms
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
...@@ -35,7 +35,7 @@ try: ...@@ -35,7 +35,7 @@ try:
class RayWorkerWrapper(WorkerWrapperBase): class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -54,10 +54,10 @@ try: ...@@ -54,10 +54,10 @@ try:
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
    """Return the Ray node id and the accelerator ids for this worker.

    The accelerator ids are looked up in Ray's runtime context under the
    current platform's ``ray_device_key``.

    Raises:
        RuntimeError: If the current platform has no ``ray_device_key``.
    """
    node_id = ray.get_runtime_context().get_node_id()
    device_key = vllm.platforms.current_platform.ray_device_key
    if not device_key:
        # Exceptions do not apply %-style formatting to their arguments,
        # so the platform name has to be interpolated explicitly.
        raise RuntimeError(
            "current platform "
            f"{vllm.platforms.current_platform.device_name} "
            "does not support ray.")
    gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
    return node_id, gpu_ids
...@@ -118,7 +118,14 @@ try: ...@@ -118,7 +118,14 @@ try:
) -> "ModelRunnerOutput": ) -> "ModelRunnerOutput":
self.setup_device_if_necessary() self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized" assert self.worker is not None, "Worker is not initialized"
output = self.worker.model_runner.execute_model(scheduler_output) if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
else:
scheduler_output, intermediate_tensors = scheduler_output, None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
return output return output
def override_env_vars(self, vars: Dict[str, str]): def override_env_vars(self, vars: Dict[str, str]):
......
...@@ -28,6 +28,11 @@ class UniProcExecutor(ExecutorBase): ...@@ -28,6 +28,11 @@ class UniProcExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
local_rank = 0 local_rank = 0
# set local rank as the device index if specified
device_info = self.vllm_config.device_config.device.__str__().split(
":")
if len(device_info) > 1:
local_rank = int(device_info[1])
rank = 0 rank = 0
kwargs = dict( kwargs = dict(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
...@@ -101,7 +106,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor): ...@@ -101,7 +106,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
# - MASTER_PORT # - MASTER_PORT
distributed_init_method = "env://" distributed_init_method = "env://"
rank = int(os.environ["RANK"]) rank = int(os.environ["RANK"])
local_rank = rank local_rank = int(os.environ["LOCAL_RANK"])
is_driver_worker = True is_driver_worker = True
kwargs = dict( kwargs = dict(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
from typing import List, Mapping, Optional, Union from typing import List, Mapping, Optional, Tuple, Union, cast
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -9,7 +9,8 @@ from vllm.config import ModelConfig ...@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
...@@ -254,14 +255,18 @@ class InputPreprocessor: ...@@ -254,14 +255,18 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata. returning the corresponding token IDs and metadata.
""" """
tokenizer_group = self.get_tokenizer_group() # At the moment on model (PrithviGeoSpatialMAE) requires to be
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) # initialized without a tokenizer while using also multi-modal
# input.
if not self.tokenizer:
tokenizer = None
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer) self.model_config, tokenizer)
if isinstance(prompt, list):
prompt = tokenizer.decode(prompt)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -275,9 +280,15 @@ class InputPreprocessor: ...@@ -275,9 +280,15 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
) -> MultiModalInputs: ) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`.""" """Async version of :meth:`_process_multimodal`."""
tokenizer_group = self.get_tokenizer_group() # At the moment on model (PrithviGeoSpatialMAE) requires to be
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request # initialized without a tokenizer while using also multi-modal
) # input.
if not self.tokenizer:
tokenizer = None
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer) self.model_config, tokenizer)
...@@ -485,6 +496,51 @@ class InputPreprocessor: ...@@ -485,6 +496,51 @@ class InputPreprocessor:
decoder=decoder_inputs, decoder=decoder_inputs,
) )
def _separate_enc_dec_inputs_from_mm_processor_outputs(
    self,
    inputs: SingletonInputs,
    decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> Tuple[SingletonInputs, SingletonInputs]:
    """
    For encoder/decoder models only:
    split processed inputs into an (encoder, decoder) pair.

    For ``"multimodal"`` inputs the encoder side is built from the
    ``encoder_prompt`` / ``encoder_prompt_token_ids`` entries, and the
    decoder side keeps the multi-modal kwargs and placeholders. For
    ``"token"`` inputs the encoder side is an empty token prompt.
    When ``decoder_inputs_to_override`` is given, its prompt and token
    ids take the place of the ones found in ``inputs``.
    """
    encoder_inputs: SingletonInputs
    decoder_inputs: SingletonInputs

    input_type = inputs["type"]
    if input_type == "multimodal":
        # Multimodal data inputs
        assert ("encoder_prompt" in inputs
                and "encoder_prompt_token_ids" in inputs)
        mm_inputs = cast(MultiModalEncDecInputs, inputs)
        encoder_inputs = token_inputs(
            prompt=mm_inputs["encoder_prompt"],
            prompt_token_ids=mm_inputs["encoder_prompt_token_ids"],
        )

        # Choose where the decoder text/tokens come from, then build the
        # decoder inputs in a single place.
        if decoder_inputs_to_override is None:
            decoder_prompt = mm_inputs["prompt"]
            decoder_token_ids = mm_inputs["prompt_token_ids"]
        else:
            decoder_prompt = decoder_inputs_to_override.get("prompt", "")
            decoder_token_ids = decoder_inputs_to_override[
                "prompt_token_ids"]
        decoder_inputs = MultiModalInputs(
            type="multimodal",
            prompt=decoder_prompt,
            prompt_token_ids=decoder_token_ids,
            mm_kwargs=mm_inputs["mm_kwargs"],
            mm_placeholders=mm_inputs["mm_placeholders"],
        )
    elif input_type == "token":
        # Text-only inputs
        encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
        if decoder_inputs_to_override:
            decoder_inputs = decoder_inputs_to_override
        else:
            decoder_inputs = inputs
    else:
        assert_never(inputs)  # type: ignore[arg-type]

    return encoder_inputs, decoder_inputs
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
prompt: PromptType, prompt: PromptType,
...@@ -529,7 +585,6 @@ class InputPreprocessor: ...@@ -529,7 +585,6 @@ class InputPreprocessor:
prompt["encoder_prompt"], prompt["encoder_prompt"],
request_id=request_id, request_id=request_id,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
else: else:
...@@ -537,13 +592,28 @@ class InputPreprocessor: ...@@ -537,13 +592,28 @@ class InputPreprocessor:
decoder_input, decoder_input,
request_id=request_id, request_id=request_id,
) )
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else: else:
encoder_inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(
prompt, prompt,
request_id=request_id, request_id=request_id,
) )
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
...@@ -573,13 +643,29 @@ class InputPreprocessor: ...@@ -573,13 +643,29 @@ class InputPreprocessor:
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task) encoder_task, decoder_task)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else: else:
encoder_inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(
prompt, prompt,
request_id=request_id, request_id=request_id,
) )
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
......
...@@ -11,8 +11,9 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin ...@@ -11,8 +11,9 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
...@@ -27,19 +28,9 @@ if TYPE_CHECKING: ...@@ -27,19 +28,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) _T = TypeVar("_T")
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin) _C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
class HashableDict(dict):
    """
    A ``dict`` subclass that defines ``__hash__`` so that instances can be
    passed where a hashable value is required (e.g. to ``lru_cache``).

    The hash is derived from the current items, independent of insertion
    order; every value stored in the dict must itself be hashable.
    """

    # NOTE: a plain dict is unhashable, so the hash is supplied here
    # directly on the subclass to keep things simple.
    def __hash__(self) -> int:  # type: ignore[override]
        item_set = frozenset(self.items())
        return hash(item_set)
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -54,9 +45,9 @@ class InputContext: ...@@ -54,9 +45,9 @@ class InputContext:
def get_hf_config( def get_hf_config(
self, self,
typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig, typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
/, /,
) -> C: ) -> _C:
""" """
Get the HuggingFace configuration Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model, (:class:`transformers.PretrainedConfig`) of the model,
...@@ -94,10 +85,10 @@ class InputContext: ...@@ -94,10 +85,10 @@ class InputContext:
def get_hf_processor( def get_hf_processor(
self, self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin, typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
/, /,
**kwargs: object, **kwargs: object,
) -> P: ) -> _P:
""" """
Get the HuggingFace processor Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model, (:class:`transformers.ProcessorMixin`) of the model,
...@@ -106,33 +97,29 @@ class InputContext: ...@@ -106,33 +97,29 @@ class InputContext:
Raises: Raises:
TypeError: If the processor is not of the specified type. TypeError: If the processor is not of the specified type.
""" """
return cached_processor_from_config(
self.model_config,
processor_cls=typ,
**kwargs,
)
def init_processor(
self,
typ: type[_T],
/,
**kwargs: object,
) -> _T:
"""
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
base_kwargs = self.model_config.mm_processor_kwargs base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None: if base_kwargs is None:
base_kwargs = {} base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs} merged_kwargs = {**base_kwargs, **kwargs}
if isinstance(typ, type): return typ(**merged_kwargs)
merged_kwargs["processor_cls"] = typ
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for key, value in merged_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
if not isinstance(hf_processor, typ):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {typ}, but "
f"found type: {type(hf_processor)}")
return hf_processor
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -142,10 +129,10 @@ class InputProcessingContext(InputContext): ...@@ -142,10 +129,10 @@ class InputProcessingContext(InputContext):
def get_hf_processor( def get_hf_processor(
self, self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin, typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
/, /,
**kwargs: object, **kwargs: object,
) -> P: ) -> _P:
return super().get_hf_processor( return super().get_hf_processor(
typ, typ,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
...@@ -341,16 +328,13 @@ class InputRegistry: ...@@ -341,16 +328,13 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer
if mm_registry.has_processor(model_config): if mm_registry.has_processor(model_config):
tokenizer = cached_get_tokenizer( tokenizer = cached_tokenizer_from_config(model_config)
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
processor = mm_registry.create_processor(model_config, tokenizer) processor = mm_registry.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len) dummy_data = profiler.get_dummy_data(
seq_len, is_encoder_data=is_encoder_data)
else: else:
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
if is_encoder_data: if is_encoder_data:
......
...@@ -31,7 +31,7 @@ def get_bad_words_logits_processors( ...@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens # Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(prompt=prompt) prompt_token_ids = tokenizer.encode(text=prompt)
else: else:
prompt_token_ids = tokenizer.encode(text=prompt, prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False) add_special_tokens=False)
......
...@@ -16,8 +16,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -16,8 +16,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce)
tensor_model_parallel_gather)
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
# yapf: disable # yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -1040,10 +1039,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1040,10 +1039,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head, hidden_states) logits = lm_head.quant_method.apply(lm_head, hidden_states)
if embedding_bias is not None: if embedding_bias is not None:
logits += embedding_bias logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Gather logits for TP
logits = self.base_layer._gather_logits(logits)
if logits is None: if logits is None:
return None return None
......
...@@ -5,7 +5,8 @@ import math ...@@ -5,7 +5,8 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
Union)
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -622,12 +623,14 @@ class LoRAModelManager(AdapterModelManager): ...@@ -622,12 +623,14 @@ class LoRAModelManager(AdapterModelManager):
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items(): for module_name, new_module_names in self.packed_modules.items():
replacement_loras: List[Optional[LoRALayerWeights]] = [] replacement_loras: List[Optional[LoRALayerWeights]] = []
replaced_module: Set[str] = set()
has_replacement = False has_replacement = False
for r in new_module_names: for r in new_module_names:
lora = lora_model.get_lora(r) lora = lora_model.get_lora(r)
replacement_loras.append(lora) replacement_loras.append(lora)
if lora: if lora:
has_replacement = True has_replacement = True
replaced_module.add(r)
if not has_replacement: if not has_replacement:
continue continue
for i in range(len(replacement_loras)): for i in range(len(replacement_loras)):
...@@ -636,6 +639,9 @@ class LoRAModelManager(AdapterModelManager): ...@@ -636,6 +639,9 @@ class LoRAModelManager(AdapterModelManager):
replacement_loras[i] = None replacement_loras[i] = None
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras) replacement_loras)
# Remove the modules that have been replaced.
for module in replaced_module:
lora_model.loras.pop(module, None)
def deactivate_adapter(self, adapter_id: int) -> bool: def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters, return deactivate_adapter(adapter_id, self._active_adapters,
......
# SPDX-License-Identifier: Apache-2.0
"""
Utilities for Punica kernel construction.
"""
import triton
import triton.language as tl
@triton.jit
def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
         EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
         b_dtype: tl.constexpr):
    """
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
    B (k x n), iterate through the K dimension to compute the partial/complete
    matrix block product.
    If SPLIT_K == 1, the output m x n product is complete.
    If SPLIT_K > 1, the thread block computes partial outputs. The partial
    outputs are then atomically summed in the caller code.

    Args:
        a_ptr: Array of pointers, identifying rows of A
        b_ptr: Array of pointers, identifying columns of B
        ak_stride: K dimension stride of the A matrix
        bk_stride: K dimension stride of the B matrix
        offset_k: K-dimension offsets of the current tile; only used to
            build the load masks when EVEN_K is False.
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without any
            masking.
        SPLIT_K: Parameter signifying parallelism in the K dimension.
        CAST_TYPE: if True, cast the values from the A matrix to the B
            matrix dtype.
        b_dtype: datatype of the B matrix

    Returns:
        The (BLOCK_M, BLOCK_N) float32 accumulator holding the block product.
    """
    # Accumulate in float32 regardless of the input dtypes.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Each iteration advances BLOCK_K * SPLIT_K elements along K.
    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            # Mask against the K extent remaining at this iteration and
            # substitute 0 for masked-off elements.
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(b_dtype)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        # Step both pointer blocks to the next K tile.
        a_ptr += BLOCK_K * SPLIT_K * ak_stride
        b_ptr += BLOCK_K * SPLIT_K * bk_stride
    return accumulator
@triton.jit
def do_expand_kernel(
    pid_n,
    lora_index,
    slice_id,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    slice_start_loc,
    # input ptr strides
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    # lora ptr strides
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    # out ptr strides
    output_d0_stride,
    output_d1_stride,
    # constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    EVEN_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice,
    compute the matrix product and store in the appropriate output location.
    Given that this is an expand kernel, we don't perform any split-K reduction
    as the K dimension is assumed to be small.

    Args:
        pid_n: Index of the BLOCK_N-wide column block handled by this call.
        lora_index: Index of the LoRA to use; scaled by the LoRA d0 stride.
        slice_id: Index of the input/output slice.
        input_ptr: Base pointer of the input.
        lora_ptr: LoRA weight pointer when SLICE_NUM == 1; otherwise a
            pointer to per-slice addresses, loaded at slice_id.
        out_ptr: Base pointer of the output.
        N: Number of output columns of the current slice.
        K: Length of the reduction dimension.
        M_LEN: Number of valid rows in this block (used for the store mask).
        ram: Row indices into the input/output to operate on.
        slice_start_loc: Column start of the slice in the output; a value
            when SLICE_NUM == 1, otherwise a pointer indexed by slice_id.
        input_d*_stride: Strides of the input.
        ls_d*_ptr: LoRA strides; plain integers when SAME_STRIDE, otherwise
            pointers indexed by slice_id.
        output_d*_stride: Strides of the output.
        ADD_INPUTS: If True, accumulate into the existing output values.
    """
    # ls_d*_ptr can be either an integer or a pointer
    if SAME_STRIDE:
        # integer
        cur_lora_d0_stride = ls_d0_ptr
        cur_lora_d1_stride = ls_d1_ptr
        cur_lora_d2_stride = ls_d2_ptr
    else:
        # pointer
        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)

    # Identify the input_ptr and lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_input_ptr = input_ptr
        cur_lora_ptr = lora_ptr
    else:
        cur_input_ptr = input_ptr + slice_id * input_d0_stride
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(out_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = tl.arange(0, BLOCK_K)
    # NOTE: a_ptr must be a plain block of pointers. A trailing comma here
    # would build a one-element tuple, which only works on Triton frontends
    # that implicitly unwrap tuples on assignment.
    a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
             offset_k[None, :] * input_d2_stride)
    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
             offset_k[:, None] * cur_lora_d2_stride +
             rbn[None, :] * cur_lora_d1_stride)

    # Compute the block matrix product.
    SPLIT_K = 1
    accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride,
                       offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K,
                       CAST_TYPE, cur_lora_ptr.dtype.element_ty)

    # Cast the float32 accumulator to the LoRA weight element type.
    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
    if SLICE_NUM == 1:
        cur_slice_start = slice_start_loc
    else:
        cur_slice_start = tl.load(slice_start_loc + slice_id)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
    offset_cm = tl.arange(0, BLOCK_M)
    c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
             offset_cn[None, :] * output_d1_stride)
    # Only rows below M_LEN and columns inside the current slice are written.
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
                                             < (cur_slice_start + N))

    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
@triton.jit
def do_shrink_kernel(
    pid_n,
    pid_sk,
    slice_id,
    lora_index,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,
    # input strides
    input_d0_stride,
    input_d1_stride,
    # lora strides
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    # output strides
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    scaling,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SLICE_NUM: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice, compute the
    matrix product and store in the appropriate output location.

    Args:
        pid_n: Index of the BLOCK_N-wide column block handled by this call.
        pid_sk: Index of the split-K partition; offsets the starting K tile.
        slice_id: Index of the LoRA/output slice.
        lora_index: Index of the LoRA to use; scaled by lora_d0_stride.
        input_ptr: Base pointer of the input.
        lora_ptr: LoRA weight pointer when SLICE_NUM == 1; otherwise a
            pointer to per-slice addresses, loaded at slice_id.
        out_ptr: Base pointer of the output.
        N: Number of output columns.
        K: Length of the reduction dimension.
        M_LEN: Number of valid rows in this block (used for the store mask).
        ram: Row indices into the input/output to operate on.
        scaling: Factor multiplied into the accumulator before write-back.
    """
    # Identify the lora_ptr from slice_id.
    if SLICE_NUM == 1:
        # current lora ptr
        cur_lora_ptr = lora_ptr
    else:
        # current lora ptr
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(input_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # Column indices are wrapped modulo N for the loads; the store below is
    # separately masked against N.
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
             offset_k[None, :] * input_d1_stride)
    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
             rbn[None, :] * lora_d1_stride +
             offset_k[:, None] * lora_d2_stride)

    # Compute partial/complete block matrix product.
    # CAST_TYPE is passed as False: A is not cast to B's dtype here.
    accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k,
                       K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False,
                       cur_lora_ptr.dtype.element_ty)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_cm = tl.arange(0, BLOCK_M)
    # With multiple slices, each slice's output is offset along output dim 0.
    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
                   slice_id * output_d0_stride)
    c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
        None, :] * output_d2_stride
    # Only rows below M_LEN and columns below N are written.
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)

    accumulator *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        # Partial products from different split-K partitions are summed
        # atomically into the same output location.
        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
...@@ -14,6 +14,7 @@ import triton.language as tl ...@@ -14,6 +14,7 @@ import triton.language as tl
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .kernel_utils import do_expand_kernel
from .utils import _get_lora_b_ptr from .utils import _get_lora_b_ptr
...@@ -63,86 +64,56 @@ def _sgmv_expand_kernel( ...@@ -63,86 +64,56 @@ def _sgmv_expand_kernel(
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
pid_m = pid // cta_n_num pid_m = pid // cta_n_num
pid_n = pid % cta_n_num pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch) M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M: if pid_m * BLOCK_M >= M:
return return
if pid_n * BLOCK_N > curr_N: if pid_n * BLOCK_N >= curr_N:
return return
lora_index = tl.load(lora_indices + cur_batch) lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1: if lora_index == -1:
return return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch) m_offset = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
offset_k = tl.arange(0, BLOCK_K) cta_m_offset = m_offset + (pid_m * BLOCK_M)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) offset_m = tl.arange(0, BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N), ram = cta_m_offset + tl.max_contiguous(
BLOCK_N) tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
# ls_d*_ptr can be either an integer or a pointer do_expand_kernel(
if SAME_STRIDE: pid_n,
# integer lora_index,
cur_lora_d0_stride = ls_d0_ptr slice_id,
cur_lora_d1_stride = ls_d1_ptr input_ptr,
cur_lora_d2_stride = ls_d2_ptr lora_ptr,
else: out_ptr,
# pointer curr_N,
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) K,
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) cta_m_len,
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) ram, # array identifying the rows of Input ptr to operate on
if SLICE_NUM == 1: slice_start_loc,
cur_input_ptr = input_ptr # input ptr strides
cur_lora_ptr = lora_ptr input_d0_stride,
input_d1_stride,
else: input_d2_stride,
cur_input_ptr = input_ptr + slice_id * input_d0_stride # lora ptr strides
cur_lora_ptr = tl.load(lora_ptr + slice_id).to( ls_d0_ptr,
tl.pointer_type(out_ptr.dtype.element_ty)) ls_d1_ptr,
ls_d2_ptr,
a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + # out ptr strides
ram[:, None] * input_d1_stride + output_d0_stride,
offset_k[None, :] * input_d2_stride, ) output_d1_stride,
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + # constants
offset_k[:, None] * cur_lora_d2_stride + BLOCK_M,
rbn[None, :] * cur_lora_d1_stride) BLOCK_N,
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) BLOCK_K,
for k in range(tl.cdiv(K, BLOCK_K)): SAME_STRIDE,
if EVEN_K: SLICE_NUM,
tiled_a = tl.load(a_ptr) EVEN_K,
tiled_b = tl.load(b_ptr) CAST_TYPE,
else: ADD_INPUTS,
tiled_a = tl.load(a_ptr, )
mask=offset_k[None, :] < K - k * BLOCK_K,
other=0)
tiled_b = tl.load(b_ptr,
mask=offset_k[:, None] < K - k * BLOCK_K,
other=0)
if CAST_TYPE:
tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * input_d2_stride
b_ptr += BLOCK_K * cur_lora_d2_stride
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
if SLICE_NUM == 1:
cur_slice_start = slice_start_loc
else:
cur_slice_start = tl.load(slice_start_loc + slice_id)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
offset_cn[None, :] * output_d1_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (
offset_cn[None, :] < (cur_slice_start + curr_N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode() @torch.inference_mode()
......
...@@ -14,6 +14,7 @@ import triton.language as tl ...@@ -14,6 +14,7 @@ import triton.language as tl
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .kernel_utils import do_shrink_kernel
from .utils import _get_lora_a_ptr from .utils import _get_lora_a_ptr
...@@ -62,67 +63,50 @@ def _sgmv_shrink_kernel( ...@@ -62,67 +63,50 @@ def _sgmv_shrink_kernel(
pid_sk = pid_mix % SPLIT_K pid_sk = pid_mix % SPLIT_K
M = tl.load(seq_lens + cur_batch) M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M: if pid_m * BLOCK_M >= M:
return return
lora_index = tl.load(lora_indices + cur_batch) lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1: if lora_index == -1:
return return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# input ptr
a_ptr = (input_ptr + cur_seq_start * input_d0_stride +
ram[:, None] * input_d0_stride +
offset_k[None, :] * input_d1_stride)
if SLICE_NUM == 1: m_offset = tl.load(b_seq_start_loc + cur_batch)
# current lora ptr
cur_lora_ptr = lora_ptr cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
else: cta_m_offset = m_offset + (pid_m * BLOCK_M)
# current lora ptr offset_m = tl.arange(0, BLOCK_M)
cur_lora_ptr = tl.load(lora_ptr + slice_id).to( ram = cta_m_offset + tl.max_contiguous(
tl.pointer_type(input_ptr.dtype.element_ty)) tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + do_shrink_kernel(
rbn[None, :] * lora_d1_stride + pid_n,
offset_k[:, None] * lora_d2_stride) pid_sk,
slice_id,
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) lora_index,
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): input_ptr,
if EVEN_K: lora_ptr,
tiled_a = tl.load(a_ptr) out_ptr,
tiled_b = tl.load(b_ptr) N,
else: K,
k_remaining = K - k * (BLOCK_K * SPLIT_K) cta_m_len,
tiled_a = tl.load(a_ptr, ram,
mask=offset_k[None, :] < k_remaining, # input strides
other=0.0) input_d0_stride,
tiled_b = tl.load(b_ptr, input_d1_stride,
mask=offset_k[:, None] < k_remaining, # lora strides
other=0.0) lora_d0_stride,
accumulator += tl.dot(tiled_a, tiled_b) lora_d1_stride,
lora_d2_stride,
a_ptr += BLOCK_K * SPLIT_K * input_d1_stride # output strides
b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride output_d0_stride,
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M output_d1_stride,
output_d2_stride,
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N scaling,
cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + BLOCK_M,
slice_id * output_d0_stride) BLOCK_N,
c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[ BLOCK_K,
None, :] * output_d2_stride EVEN_K,
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] SPLIT_K,
< N) SLICE_NUM)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode() @torch.inference_mode()
......
...@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
dtype=torch.long, dtype=torch.long,
device=device) device=device)
# 5 is the number of indicies tensors. # 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices # embeddings_indices,long_lora_indices
self.indices_len: List[Optional[int]] = [None] * 5 self.indices_len: List[Optional[int]] = [None] * 5
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union, final from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
import torch import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
dispatch_bgmv_linear) dispatch_bgmv_linear)
from .punica_base import PunicaWrapperBase from .punica_base import PunicaWrapperBase
from .utils import convert_mapping
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
@final @final
...@@ -19,6 +25,55 @@ class PunicaWrapperHPU(PunicaWrapperBase): ...@@ -19,6 +25,55 @@ class PunicaWrapperHPU(PunicaWrapperBase):
PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
max_batches, device) max_batches, device)
def _update_base_metadata(
    self,
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
):
    """
    Recompute the LoRA index tensors for ``mapping`` and copy them into
    this wrapper's preallocated buffers.

    ``convert_mapping`` is called with no long-LoRA context; when
    ``long_lora_context`` is given, the long-LoRA offsets are built here
    instead, from a Python list.
    """
    (base, sampler, sampler_padded, embeddings, long_offsets_tensor,
     lens) = convert_mapping(mapping, lora_index_to_id, max_loras,
                             vocab_size, extra_vocab_size, self.device, None)

    # Updating each element in `long_lora_offsets` with `lora_offset` slows
    # down perf in HPU due to a series of `strided_insert` ops during lazy
    # graph accumulation. Hence HPU collects the offsets in a list and
    # converts it to a tensor only after it is ready.
    if long_lora_context:
        offsets: List[int] = [
            long_lora_context.offsets_by_lora_id.get(lora_id, 0)
            for lora_id in list(mapping.index_mapping)
        ]
        long_offsets_tensor = torch.tensor(offsets,
                                           device=self.device,
                                           dtype=torch.long)
        lens[-1] = long_offsets_tensor.shape[-1]

    # Copy each freshly computed tensor into the leading region of its
    # preallocated buffer.
    n_base = base.shape[0]
    self._token_lora_indices[:n_base].copy_(base)

    n_sampler = sampler.shape[0]
    self._sampler_indices[:n_sampler].copy_(sampler)

    n_sampler_padded = sampler_padded.shape[0]
    self._sampler_indices_padded[:n_sampler_padded].copy_(sampler_padded)

    emb_rows = embeddings.shape[0]
    emb_cols = embeddings.shape[1]
    self._embeddings_indices[:emb_rows, :emb_cols].copy_(embeddings)

    if long_offsets_tensor is None:
        # No long-LoRA offsets for this batch: clear the buffer.
        self._long_lora_indices.zero_()
    else:
        n_long = long_offsets_tensor.shape[0]
        self._long_lora_indices[:n_long].copy_(long_offsets_tensor)

    self.indices_len[:] = lens
def add_lora_embedding(self, def add_lora_embedding(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment