Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
...@@ -37,6 +37,7 @@ if TYPE_CHECKING: ...@@ -37,6 +37,7 @@ if TYPE_CHECKING:
VLLM_CONFIGURE_LOGGING: int = 1 VLLM_CONFIGURE_LOGGING: int = 1
VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_PREFIX: str = ""
VLLM_LOGGING_STREAM: str = "ext://sys.stdout"
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
VLLM_LOG_STATS_INTERVAL: float = 10. VLLM_LOG_STATS_INTERVAL: float = 10.
...@@ -162,12 +163,18 @@ if TYPE_CHECKING: ...@@ -162,12 +163,18 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
def get_default_cache_root(): def get_default_cache_root():
...@@ -235,7 +242,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -235,7 +242,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ================== # ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default), # Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu] # rocm, cpu]
"VLLM_TARGET_DEVICE": "VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(),
...@@ -431,6 +438,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -431,6 +438,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_LEVEL": "VLLM_LOGGING_LEVEL":
lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(),
# this is used for configuring the default logging stream
"VLLM_LOGGING_STREAM":
lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"),
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
"VLLM_LOGGING_PREFIX": "VLLM_LOGGING_PREFIX":
lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),
...@@ -463,6 +474,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -463,6 +474,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "ROCM_FLASH": use ROCmFlashAttention # - "ROCM_FLASH": use ROCmFlashAttention
# - "FLASHINFER": use flashinfer # - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA # - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
"VLLM_ATTENTION_BACKEND": "VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
...@@ -994,6 +1006,15 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -994,6 +1006,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))), lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
# If set to 1, use the FlashInfer CUTLASS backend for
# MXFP8 (activation) x MXFP4 (weight) MoE.
# This is separate from the TRTLLMGEN path controlled by
# VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8.
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS":
lambda: bool(int(
os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")
)),
# If set to 1, use the FlashInfer # If set to 1, use the FlashInfer
# BF16 (activation) x MXFP4 (weight) MoE backend. # BF16 (activation) x MXFP4 (weight) MoE backend.
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
...@@ -1063,7 +1084,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1063,7 +1084,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm should use flashinfer fused allreduce. The variable should be a # vllm should use flashinfer fused allreduce. The variable should be a
# JSON with the following format: # JSON with the following format:
# { <world size>: <max size in mb> } # { <world size>: <max size in mb> }
# Unspecified world sizes will fallback to # Unspecified world sizes will fall back to
# { 2: 64, 4: 1, <everything else>: 0.5 } # { 2: 64, 4: 1, <everything else>: 0.5 }
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
lambda: json.loads(os.getenv( lambda: json.loads(os.getenv(
...@@ -1135,6 +1156,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1135,6 +1156,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION": "VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set to 1, when we use fp8 kv, we do not quantize Q to fp8
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),
# If set, it means we pre-downloaded cubin files and flashinfer will # If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly. # read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN": "VLLM_HAS_FLASHINFER_CUBIN":
...@@ -1199,6 +1224,23 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1199,6 +1224,23 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TUNED_CONFIG_FOLDER": "VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Allows vllm use container tool
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
# Allows harmony instructions to be injected on system messages
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
lambda: bool(
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))),
# Add optional custom scopes for profiling, disable to avoid overheads
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
# Represent block hashes in KV cache events as 64-bit integers instead of
# raw bytes. Defaults to True for backward compatibility.
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
...@@ -1269,9 +1311,11 @@ def compute_hash() -> str: ...@@ -1269,9 +1311,11 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_USE_CUDNN_PREFILL", "VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION", "VLLM_USE_TRTLLM_ATTENTION",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
"VLLM_ROCM_USE_AITER", "VLLM_ROCM_USE_AITER",
"VLLM_ROCM_USE_AITER_PAGED_ATTN", "VLLM_ROCM_USE_AITER_PAGED_ATTN",
"VLLM_ROCM_USE_AITER_LINEAR", "VLLM_ROCM_USE_AITER_LINEAR",
......
...@@ -231,7 +231,7 @@ class ExecutorBase(ABC): ...@@ -231,7 +231,7 @@ class ExecutorBase(ABC):
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shutdown the executor.""" """Shutdown the executor."""
return self.collective_rpc("shutdown")
def __del__(self): def __del__(self):
self.shutdown() self.shutdown()
......
...@@ -117,10 +117,12 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -117,10 +117,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
self.driver_worker.execute_method) self.driver_worker.execute_method)
def shutdown(self) -> None: def shutdown(self) -> None:
logger.info( if logger:
"Shutting down Ray distributed executor. If you see error log " # Somehow logger can be None here.
"from logging.cc regarding SIGTERM received, please ignore because " logger.info(
"this is the expected termination process in Ray.") "Shutting down Ray distributed executor. If you see error log "
"from logging.cc regarding SIGTERM received, please ignore "
"because this is the expected termination process in Ray.")
if hasattr(self, "forward_dag") and self.forward_dag is not None: if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown() self.forward_dag.teardown()
import ray import ray
......
...@@ -223,7 +223,7 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): ...@@ -223,7 +223,7 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
""" """
# Wait until PG is ready - this will block until all # Wait until PG is ready - this will block until all
# requested resources are available, and will timeout # requested resources are available, and will time out
# if they cannot be provisioned. # if they cannot be provisioned.
placement_group_specs = current_placement_group.bundle_specs placement_group_specs = current_placement_group.bundle_specs
......
...@@ -71,6 +71,10 @@ class UniProcExecutor(ExecutorBase): ...@@ -71,6 +71,10 @@ class UniProcExecutor(ExecutorBase):
self.shutdown() self.shutdown()
return return
def shutdown(self) -> None:
if worker := self.driver_worker:
worker.shutdown()
UniProcExecutorAsync = UniProcExecutor UniProcExecutorAsync = UniProcExecutor
......
...@@ -52,6 +52,9 @@ class TokensPrompt(TypedDict): ...@@ -52,6 +52,9 @@ class TokensPrompt(TypedDict):
prompt_token_ids: list[int] prompt_token_ids: list[int]
"""A list of token IDs to pass to the model.""" """A list of token IDs to pass to the model."""
prompt: NotRequired[str]
"""The prompt text corresponding to the token IDs, if available."""
token_type_ids: NotRequired[list[int]] token_type_ids: NotRequired[list[int]]
"""A list of token type IDs to pass to the cross encoder model.""" """A list of token type IDs to pass to the cross encoder model."""
......
...@@ -258,8 +258,7 @@ class InputPreprocessor: ...@@ -258,8 +258,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
...@@ -276,13 +275,23 @@ class InputPreprocessor: ...@@ -276,13 +275,23 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply( mm_input = mm_processor.apply(
prompt, prompt,
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
if not contains_only_strings(mm_hashes):
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method.")
return mm_input
async def _process_multimodal_async( async def _process_multimodal_async(
self, self,
...@@ -292,8 +301,7 @@ class InputPreprocessor: ...@@ -292,8 +301,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Async version of Async version of
...@@ -310,13 +318,23 @@ class InputPreprocessor: ...@@ -310,13 +318,23 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply( mm_input = mm_processor.apply(
prompt, prompt,
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
if not contains_only_strings(mm_hashes):
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method.")
return mm_input
def _process_embeds( def _process_embeds(
self, self,
...@@ -370,8 +388,7 @@ class InputPreprocessor: ...@@ -370,8 +388,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs( prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs) parsed_content["prompt_token_ids"], tokenization_kwargs)
...@@ -384,7 +401,7 @@ class InputPreprocessor: ...@@ -384,7 +401,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids) inputs = token_inputs(prompt_token_ids=prompt_token_ids)
...@@ -400,8 +417,7 @@ class InputPreprocessor: ...@@ -400,8 +417,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs( prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs) parsed_content["prompt_token_ids"], tokenization_kwargs)
...@@ -414,7 +430,7 @@ class InputPreprocessor: ...@@ -414,7 +430,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids, ) inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
...@@ -430,8 +446,7 @@ class InputPreprocessor: ...@@ -430,8 +446,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -443,7 +458,7 @@ class InputPreprocessor: ...@@ -443,7 +458,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
...@@ -467,8 +482,7 @@ class InputPreprocessor: ...@@ -467,8 +482,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -480,7 +494,7 @@ class InputPreprocessor: ...@@ -480,7 +494,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
...@@ -504,8 +518,7 @@ class InputPreprocessor: ...@@ -504,8 +518,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Extract the singleton inputs from a prompt. Extract the singleton inputs from a prompt.
...@@ -527,21 +540,21 @@ class InputPreprocessor: ...@@ -527,21 +540,21 @@ class InputPreprocessor:
return self._process_tokens( return self._process_tokens(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return self._process_text( return self._process_text(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
assert_never(parsed) assert_never(parsed)
...@@ -552,8 +565,7 @@ class InputPreprocessor: ...@@ -552,8 +565,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Async version of Async version of
...@@ -567,21 +579,21 @@ class InputPreprocessor: ...@@ -567,21 +579,21 @@ class InputPreprocessor:
return await self._process_tokens_async( return await self._process_tokens_async(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return await self._process_text_async( return await self._process_text_async(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return await self._process_text_async( return await self._process_text_async(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
assert_never(parsed) assert_never(parsed)
...@@ -692,8 +704,7 @@ class InputPreprocessor: ...@@ -692,8 +704,7 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
...@@ -735,7 +746,7 @@ class InputPreprocessor: ...@@ -735,7 +746,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs( encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
...@@ -751,7 +762,7 @@ class InputPreprocessor: ...@@ -751,7 +762,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -768,8 +779,7 @@ class InputPreprocessor: ...@@ -768,8 +779,7 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
Async version of Async version of
...@@ -782,7 +792,7 @@ class InputPreprocessor: ...@@ -782,7 +792,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async( encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
...@@ -792,7 +802,7 @@ class InputPreprocessor: ...@@ -792,7 +802,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async( decoder_task = self._prompt_to_llm_inputs_async(
decoder_input, decoder_input,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
...@@ -808,7 +818,7 @@ class InputPreprocessor: ...@@ -808,7 +818,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -836,8 +846,7 @@ class InputPreprocessor: ...@@ -836,8 +846,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
For decoder-only models: For decoder-only models:
...@@ -858,7 +867,7 @@ class InputPreprocessor: ...@@ -858,7 +867,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -869,8 +878,7 @@ class InputPreprocessor: ...@@ -869,8 +878,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
Async version of Async version of
...@@ -880,7 +888,7 @@ class InputPreprocessor: ...@@ -880,7 +888,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -891,8 +899,7 @@ class InputPreprocessor: ...@@ -891,8 +899,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
...@@ -901,7 +908,7 @@ class InputPreprocessor: ...@@ -901,7 +908,7 @@ class InputPreprocessor:
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -913,7 +920,7 @@ class InputPreprocessor: ...@@ -913,7 +920,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
async def preprocess_async( async def preprocess_async(
...@@ -922,8 +929,7 @@ class InputPreprocessor: ...@@ -922,8 +929,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Async version of Async version of
...@@ -935,7 +941,7 @@ class InputPreprocessor: ...@@ -935,7 +941,7 @@ class InputPreprocessor:
return await self._process_encoder_decoder_prompt_async( return await self._process_encoder_decoder_prompt_async(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -947,9 +953,21 @@ class InputPreprocessor: ...@@ -947,9 +953,21 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
def clear_cache(self) -> None: def clear_cache(self) -> None:
if self.mm_processor_cache is not None: if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache() self.mm_processor_cache.clear_cache()
# Helper function to validate that a nested dictionary contains
# only strings or list of strings as the leaf values.
def contains_only_strings(obj: object):
if isinstance(obj, str):
return True
if isinstance(obj, list):
return all(isinstance(x, str) for x in obj)
if isinstance(obj, dict):
return all(contains_only_strings(v) for v in obj.values())
return False
...@@ -20,9 +20,10 @@ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING ...@@ -20,9 +20,10 @@ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
VLLM_LOGGING_STREAM = envs.VLLM_LOGGING_STREAM
_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " _FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
"[%(filename)s:%(lineno)d] %(message)s") "[%(fileinfo)s:%(lineno)d] %(message)s")
_DATE_FORMAT = "%m-%d %H:%M:%S" _DATE_FORMAT = "%m-%d %H:%M:%S"
DEFAULT_LOGGING_CONFIG = { DEFAULT_LOGGING_CONFIG = {
...@@ -38,7 +39,7 @@ DEFAULT_LOGGING_CONFIG = { ...@@ -38,7 +39,7 @@ DEFAULT_LOGGING_CONFIG = {
"class": "logging.StreamHandler", "class": "logging.StreamHandler",
"formatter": "vllm", "formatter": "vllm",
"level": VLLM_LOGGING_LEVEL, "level": VLLM_LOGGING_LEVEL,
"stream": "ext://sys.stdout", "stream": VLLM_LOGGING_STREAM,
}, },
}, },
"loggers": { "loggers": {
......
...@@ -2,16 +2,77 @@ ...@@ -2,16 +2,77 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging import logging
from pathlib import Path
from vllm import envs
class NewLineFormatter(logging.Formatter): class NewLineFormatter(logging.Formatter):
"""Adds logging prefix to newlines to align multi-line messages.""" """Adds logging prefix to newlines to align multi-line messages."""
def __init__(self, fmt, datefmt=None, style="%"): def __init__(self, fmt, datefmt=None, style="%"):
logging.Formatter.__init__(self, fmt, datefmt, style) super().__init__(fmt, datefmt, style)
self.use_relpath = envs.VLLM_LOGGING_LEVEL == "DEBUG"
if self.use_relpath:
self.root_dir = Path(__file__).resolve().parent.parent.parent
def format(self, record): def format(self, record):
msg = logging.Formatter.format(self, record)
def shrink_path(relpath: Path) -> str:
"""
Shortens a file path for logging display:
- Removes leading 'vllm' folder if present.
- If path starts with 'v1',
keeps the first two and last two levels,
collapsing the middle as '...'.
- Otherwise, keeps the first and last two levels,
collapsing the middle as '...'.
- If the path is short, returns it as-is.
- Examples:
vllm/model_executor/layers/quantization/utils/fp8_utils.py ->
model_executor/.../quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/awq.py ->
model_executor/layers/quantization/awq.py
vllm/v1/attention/backends/mla/common.py ->
v1/attention/backends/mla/common.py
Args:
relpath (Path): The relative path to be shortened.
Returns:
str: The shortened path string for display.
"""
parts = list(relpath.parts)
new_parts = []
if parts and parts[0] == "vllm":
parts = parts[1:]
if parts and parts[0] == "v1":
new_parts += parts[:2]
parts = parts[2:]
elif parts:
new_parts += parts[:1]
parts = parts[1:]
if len(parts) > 2:
new_parts += ["..."] + parts[-2:]
else:
new_parts += parts
return "/".join(new_parts)
if self.use_relpath:
abs_path = getattr(record, "pathname", None)
if abs_path:
try:
relpath = Path(abs_path).resolve().relative_to(
self.root_dir)
except Exception:
relpath = Path(record.filename)
else:
relpath = Path(record.filename)
record.fileinfo = shrink_path(relpath)
else:
record.fileinfo = record.filename
msg = super().format(record)
if record.message != "": if record.message != "":
parts = msg.split(record.message) parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0]) msg = msg.replace("\n", "\r\n" + parts[0])
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
"""Infos for supporting OpenAI compatible logprobs and token ranks.
Attributes:
logprob: The logprob of chosen token
rank: The vocab rank of chosen token (>=1)
decoded_token: The decoded chosen token index
"""
logprob: float
rank: Optional[int] = None
decoded_token: Optional[str] = None
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = list[dict[int, Logprob]]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA)
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA, RowParallelLinearWithShardedLoRA)
from vllm.lora.layers.utils import LoRAMapping
from vllm.lora.layers.vocal_parallel_embedding import (
VocabParallelEmbeddingWithLoRA)
__all__ = [
"BaseLayerWithLoRA",
"VocabParallelEmbeddingWithLoRA",
"LogitsProcessorWithLoRA",
"ColumnParallelLinearWithLoRA",
"ColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearWithLoRA",
"MergedColumnParallelLinearWithShardedLoRA",
"MergedQKVParallelLinearWithLoRA",
"MergedQKVParallelLinearWithShardedLoRA",
"QKVParallelLinearWithLoRA",
"QKVParallelLinearWithShardedLoRA",
"RowParallelLinearWithLoRA",
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
if TYPE_CHECKING:
from vllm.lora.punica_wrapper import PunicaWrapperBase
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(
self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(
self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora b if splitting with tensor parallelism."""
...
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
"""Initializes lora matrices."""
...
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
"""Overwrites lora tensors at index."""
...
def set_mapping(
self,
punica_wrapper,
):
self.punica_wrapper: PunicaWrapperBase = punica_wrapper
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, cast
import torch
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, ReplicatedLinear,
RowParallelLinear)
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
from .utils import _get_lora_device
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: LinearBase):
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: tuple[int, ...]
self.tp_size: int
self.output_size: int
self.n_slices: int
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
#
if isinstance(self.base_layer, ReplicatedLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, ColumnParallelLinear):
lora_a_out_size = (lora_config.max_lora_rank if
not lora_config.fully_sharded_loras else divide(
lora_config.max_lora_rank, self.tp_size))
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, RowParallelLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = (self.output_size if
not lora_config.fully_sharded_loras else divide(
self.output_size, self.tp_size))
else:
raise NotImplementedError
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_a_out_size,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_b_out_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
if lora_config.bias_enabled:
lora_bias_out_size = lora_b_out_size
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_bias_out_size,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
self.output_slices = (self.lora_b_stacked[0].shape[2], )
def reset_lora(self, index: int):
for s_index in range(self.n_slices):
self.lora_a_stacked[s_index][index] = 0
self.lora_b_stacked[s_index][index] = 0
if self.lora_config.bias_enabled:
# Make mypy happy
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
self.lora_bias_stacked[s_index][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
self.n_slices == 1)
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)
self.lora_a_stacked[0][index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[0][index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
lora_bias.T, non_blocking=True)
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
output = output.flatten(0, 1)
x = x.flatten(0, 1)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked,
self.lora_bias_stacked, 1.0, self.output_slices)
if not current_platform.can_update_inplace():
output = lora_output
return output
@property
def weight(self) -> torch.Tensor:
# unquantizedLinear
if hasattr(self.base_layer, "weight"):
return self.base_layer.weight
# Compressed Tensor
elif hasattr(self.base_layer, "weight_packed"):
return self.base_layer.weight_packed
# GPTQ/AWQ
elif hasattr(self.base_layer, "qweight"):
return self.base_layer.qweight
# marlin
elif hasattr(self.base_layer, "B"):
return self.base_layer.B
# HQQ marlin
elif hasattr(self.base_layer, "W_q"):
return self.base_layer.W_q
else:
raise ValueError(f"Unsupported base layer: {self.base_layer}")
@property
def bias(self) -> Optional[torch.Tensor]:
if hasattr(self.base_layer, "bias"):
return self.base_layer.bias
else:
return None
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# pylint: disable=unused-argument from typing import Optional, Union, cast
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.adapter_commons.layers import AdapterMapping from vllm.config.lora import LoRAConfig
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, tensor_model_parallel_all_gather)
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear)
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform from vllm.platforms import current_platform
if TYPE_CHECKING: from .base_linear import BaseLinearLayerWithLoRA
from vllm.lora.punica_wrapper import PunicaWrapperBase from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
# unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
# Compressed Tensor
elif hasattr(base_layer, "weight_packed"):
return base_layer.weight_packed.device
# GPTQ/AWQ
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# HQQ marlin
elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _not_fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
condition = (not kwargs["lora_config"].fully_sharded_loras
if decorate else True)
return can_replace(*args, **kwargs) and condition
return dec
@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(
self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(
self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora b if splitting with tensor parallelism."""
...
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
"""Initializes lora matrices."""
...
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
"""Overwrites lora tensors at index."""
...
def set_mapping(
self,
punica_wrapper,
):
self.punica_wrapper: PunicaWrapperBase = punica_wrapper
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: Optional[tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
if self.base_layer.num_added_embeddings_per_partition > 0:
# We can start adding lora weights
self.embeddings_weights = self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:self.
base_layer.num_org_embeddings_per_partition +
self.base_layer.num_added_embeddings_per_partition]
self.embeddings_slice = (
self.base_layer.shard_indices.added_vocab_start_index -
self.base_layer.org_vocab_size,
self.base_layer.shard_indices.added_vocab_end_index -
self.base_layer.org_vocab_size)
self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:].fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size +
lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.embedding_dim,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
:embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] *
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2],
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
1, 0)
# NB: Don't use torch.narrow here. torch.narrow triggers some
# Dynamic Shape specialization in torch.compile
num_tokens = x.shape[0]
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
full_lora_a_embeddings = F.embedding(
x + indices_1,
self.lora_a_stacked_2d,
)
full_output = self.base_layer.forward(x +
(indices_0 * added_tokens_mask))
full_output_org = full_output
if full_output.ndim == 3:
full_output = full_output.view(
full_output.shape[0] * full_output.shape[1], -1)
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1],
-1,
)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_embedding(
full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)
if not current_platform.can_update_inplace():
full_output = lora_output
return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is VocabParallelEmbedding
@property
def weight(self):
return self.base_layer.weight
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: LinearBase):
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: tuple[int, ...]
self.tp_size: int
self.output_size: int
self.n_slices: int
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
#
if isinstance(self.base_layer, ReplicatedLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, ColumnParallelLinear):
lora_a_out_size = (lora_config.max_lora_rank if
not lora_config.fully_sharded_loras else divide(
lora_config.max_lora_rank, self.tp_size))
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, RowParallelLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = (self.output_size if
not lora_config.fully_sharded_loras else divide(
self.output_size, self.tp_size))
else:
raise NotImplementedError
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_a_out_size,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_b_out_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
if lora_config.bias_enabled:
lora_bias_out_size = lora_b_out_size
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_bias_out_size,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
self.output_slices = (self.lora_b_stacked[0].shape[2], )
def reset_lora(self, index: int):
for s_index in range(self.n_slices):
self.lora_a_stacked[s_index][index] = 0
self.lora_b_stacked[s_index][index] = 0
if self.lora_config.bias_enabled:
# Make mypy happy
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
self.lora_bias_stacked[s_index][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
self.n_slices == 1)
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)
self.lora_a_stacked[0][index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[0][index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
lora_bias.T, non_blocking=True)
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
output = output.flatten(0, 1)
x = x.flatten(0, 1)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked,
self.lora_bias_stacked, 1.0, self.output_slices)
if not current_platform.can_update_inplace():
output = lora_output
return output
@property
def weight(self) -> torch.Tensor:
# unquantizedLinear
if hasattr(self.base_layer, "weight"):
return self.base_layer.weight
# Compressed Tensor
elif hasattr(self.base_layer, "weight_packed"):
return self.base_layer.weight_packed
# GPTQ/AWQ
elif hasattr(self.base_layer, "qweight"):
return self.base_layer.qweight
# marlin
elif hasattr(self.base_layer, "B"):
return self.base_layer.B
# HQQ marlin
elif hasattr(self.base_layer, "W_q"):
return self.base_layer.W_q
else:
raise ValueError(f"Unsupported base layer: {self.base_layer}")
@property
def bias(self) -> Optional[torch.Tensor]:
if hasattr(self.base_layer, "bias"):
return self.base_layer.bias
else:
return None
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
"""
For `ColumnParallelLinearWithLoRA` or classes that inherit from
`ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
"""
assert (layer.n_slices == len(layer.lora_a_stacked) == len(
layer.lora_b_stacked) == len(layer.output_slices))
if layer.lora_bias_stacked is not None:
assert layer.n_slices == len(layer.lora_bias_stacked)
def __init__(self, base_layer: ReplicatedLinear) -> None: output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
super().__init__(base_layer, )
# To ensure interface compatibility, set to 1 always.
self.tp_size = 1
self.output_size = self.base_layer.output_size
self.n_slices = 1
def forward( x = x.view(-1, x.shape[-1])
self, input_: torch.Tensor output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Forward of ReplicatedLinearWithLoRA
Args: # Since communication is needed, the buffer is directly initialized as a
input_: Tensor whose last dimension is `input_size`. # tensor rather than a tuple of tensor.
buffers = torch.zeros(
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
Returns: shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink(
- output buffers, x, layer.lora_a_stacked, 1.0)
- bias
"""
bias = (self.base_layer.bias
if not self.base_layer.skip_bias_add else None)
# Matrix multiply. if not current_platform.can_update_inplace():
output = self.apply(input_, bias) buffers = shrunk_buffers
output_bias = (self.base_layer.bias buffers = tensor_model_parallel_all_gather(buffers)
if self.base_layer.skip_bias_add else None)
if not self.base_layer.return_bias: lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand(
return output output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)
return output, output_bias if not current_platform.can_update_inplace():
output = lora_output
# ReplicatedLinear should always be replaced, regardless of the fully output = output.view(*out_orig_shape)
# sharded LoRAs setting, because it is, by definition, copied per GPU. # now have column partitioned and packed output
@classmethod return output
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ReplicatedLinear
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...@@ -876,84 +444,37 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): ...@@ -876,84 +444,37 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and len(packed_modules_list) == 3) and len(packed_modules_list) == 3)
#TODO: Implement this # These following layers are based on the tensor parallelism strategy given in
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): # Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
pass # https://arxiv.org/abs/2311.03285.
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None: class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
super().__init__(base_layer) """
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
self.tp_size = get_tensor_model_parallel_world_size()
# reset input_size
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
self.tp_rank = get_tensor_model_parallel_rank() Based on S-LoRA, slicing happens along the rank dim.
# There is only one LoRA layer. """
self.n_slices = 1
# For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
# their `lora_a` and `lora_b` have different sharding patterns. After
# completing the `lora_a` GEMM , a gather operation is performed.
# Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.input_size shard_size = self.lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size start_idx = tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size lora_a = lora_a[:, start_idx:start_idx + shard_size]
lora_a = lora_a[start_idx:end_idx, :]
return lora_a return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def apply(self,
return lora_b x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return _mcp_apply(x, bias, self)
return bias
def forward(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (output_ + self.base_layer.bias
if self.base_layer.bias is not None else output_)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
if not self.base_layer.return_bias:
return output
return output, output_bias
@classmethod @classmethod
@_not_fully_sharded_can_replace @_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
...@@ -961,226 +482,129 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -961,226 +482,129 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
return type(source_layer) is RowParallelLinear # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class LogitsProcessorWithLoRA(BaseLayerWithLoRA): class MergedColumnParallelLinearWithShardedLoRA(
""" MergedColumnParallelLinearWithLoRA):
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
""" """
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
def __init__(self, base_layer: LogitsProcessor, hidden_size: int, Based on S-LoRA, slicing happens along the rank dim.
dtype: torch.dtype, device: torch.device, """
sharded_to_full_mapping: Optional[list[int]]) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.sharded_to_full_mapping = sharded_to_full_mapping
@property
def logits_as_input(self):
return self.base_layer.logits_as_input
@property
def vocab_size(self):
return self.base_layer.vocab_size
@property
def scale(self):
return self.base_layer.scale
@property
def soft_cap(self):
return self.base_layer.soft_cap
@property
def use_all_gather(self):
return self.base_layer.use_all_gather
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@property def slice_lora_a(
def include_gpu_probs_tensor(self): self, lora_a: list[Union[torch.Tensor, None]]
return self.base_layer.include_gpu_probs_tensor ) -> list[Union[torch.Tensor, None]]:
#NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[0] is not None else None,
lora_a[1][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[1] is not None else None,
]
return lora_a
@property def apply(self,
def should_modify_greedy_probs_inplace(self): x: torch.Tensor,
return self.base_layer.should_modify_greedy_probs_inplace bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
def create_lora_weights( @classmethod
self, @_fully_sharded_can_replace
max_loras: int, def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None, packed_modules_list: list,
) -> None: model_config: Optional[PretrainedConfig],
# TODO: Verify if this condition can be further relaxed ) -> bool:
if 32000 < self.base_layer.vocab_size > 257024: # specifying kwargs so they can be easily accessed in decorator
raise ValueError("When using LoRA, vocab size must be " return super().can_replace_layer(
"32000 >= vocab_size <= 257024") source_layer=source_layer,
self.lora_a_stacked = torch.zeros( lora_config=lora_config,
( packed_modules_list=packed_modules_list,
max_loras, model_config=model_config,
1, decorate=False,
lora_config.max_lora_rank,
self.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
# Pad for kernel compatibility
math.ceil(self.base_layer.vocab_size /
lora_config.lora_vocab_padding_size) *
lora_config.lora_vocab_padding_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.embeddings_tensors = torch.full(
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
fill_value=float("-inf"),
dtype=self.dtype,
device=self.device,
) )
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping,
device=self.device,
dtype=torch.long)
else:
self.sharded_to_full_mapping_gpu = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")
def set_lora( class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
self, """
index: int, Differs from QKVParallelLinearWithLoRA by slicing the
lora_a: torch.Tensor, LoRA A's also.
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
:embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
] = embeddings_tensor
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
# Gather logits for TP
logits = self.base_layer._gather_logits(logits)
if logits is None:
return None
if self.sharded_to_full_mapping_gpu is not None:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],
hidden_states.shape[0],
dtype=self.embeddings_tensors.dtype,
device=self.embeddings_tensors.device,
)
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
neg_inf, pos_inf = current_platform.get_infinity_values( Based on S-LoRA, slicing happens along the rank dim.
lora_logits.dtype) """
lora_logits[-1] = neg_inf def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
lora_logits = lora_logits.mT tp_rank = get_tensor_model_parallel_rank()
indices_padded = self.punica_wrapper.sampler_indices_padded shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
if current_platform.is_tpu(): def apply(self,
indices_padded = indices_padded[:logits.size(0)] x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
lora_logits = (lora_logits.reshape( @classmethod
lora_logits.shape[0] * lora_logits.shape[1], @_fully_sharded_can_replace
lora_logits.shape[2], def can_replace_layer(cls, source_layer: nn.Module,
).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, lora_config: LoRAConfig, packed_modules_list: list,
posinf=pos_inf, model_config: Optional[PretrainedConfig]) -> bool:
neginf=neg_inf)) # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
lora_output: Optional[ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
torch.Tensor] = self.punica_wrapper.add_lora_logits( """
logits, hidden_states, self.lora_a_stacked, Differs from MergedQKVParallelLinearWithLoRA by slicing the
self.lora_b_stacked, 1.0) LoRA A's also.
if not current_platform.can_update_inplace(): Based on S-LoRA, slicing happens along the rank dim.
logits = lora_output """
# Remove paddings in vocab (if any). def slice_lora_a(
logits = logits[:, :self.base_layer.vocab_size] self, lora_a: list[Union[torch.Tensor, None]]
return logits ) -> list[Union[torch.Tensor, None]]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] +
shard_size[0]] if lora_a[0] is not None else None,
lora_a[1][:, start_idx[1]:start_idx[1] +
shard_size[1]] if lora_a[1] is not None else None,
lora_a[2][:, start_idx[2]:start_idx[2] +
shard_size[2]] if lora_a[2] is not None else None,
]
return lora_a
def forward(self, *args, **kwargs): def apply(self,
return type(self.base_layer).forward(self, *args, **kwargs) x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod @classmethod
@_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
...@@ -1188,5 +612,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1188,5 +612,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
# Special handling for the LogitsProcessor. # specifying kwargs so they can be easily accessed in decorator
return False return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""
def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
dtype: torch.dtype, device: torch.device,
sharded_to_full_mapping: Optional[list[int]]) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.sharded_to_full_mapping = sharded_to_full_mapping
@property
def logits_as_input(self):
return self.base_layer.logits_as_input
@property
def vocab_size(self):
return self.base_layer.vocab_size
@property
def scale(self):
return self.base_layer.scale
@property
def soft_cap(self):
return self.base_layer.soft_cap
@property
def use_all_gather(self):
return self.base_layer.use_all_gather
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@property
def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor
@property
def should_modify_greedy_probs_inplace(self):
return self.base_layer.should_modify_greedy_probs_inplace
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
# TODO: Verify if this condition can be further relaxed
if 32000 < self.base_layer.vocab_size > 257024:
raise ValueError("When using LoRA, vocab size must be "
"32000 >= vocab_size <= 257024")
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
# Pad for kernel compatibility
math.ceil(self.base_layer.vocab_size /
lora_config.lora_vocab_padding_size) *
lora_config.lora_vocab_padding_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.embeddings_tensors = torch.full(
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
fill_value=float("-inf"),
dtype=self.dtype,
device=self.device,
)
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping,
device=self.device,
dtype=torch.long)
else:
self.sharded_to_full_mapping_gpu = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
:embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
] = embeddings_tensor
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
# Gather logits for TP
logits = self.base_layer._gather_logits(logits)
if logits is None:
return None
if self.sharded_to_full_mapping_gpu is not None:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],
hidden_states.shape[0],
dtype=self.embeddings_tensors.dtype,
device=self.embeddings_tensors.device,
)
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
neg_inf, pos_inf = current_platform.get_infinity_values(
lora_logits.dtype)
lora_logits[-1] = neg_inf
lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded
if current_platform.is_tpu() or current_platform.is_xpu():
indices_padded = indices_padded[:logits.size(0)]
lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
posinf=pos_inf,
neginf=neg_inf))
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_logits(
logits, hidden_states, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
if not current_platform.can_update_inplace():
logits = lora_output
# Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size]
return logits
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# Special handling for the LogitsProcessor.
return False
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .base import BaseLayerWithLoRA
#TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
pass
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from .base_linear import BaseLinearLayerWithLoRA
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__(base_layer, )
# To ensure interface compatibility, set to 1 always.
self.tp_size = 1
self.output_size = self.base_layer.output_size
self.n_slices = 1
def forward(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Forward of ReplicatedLinearWithLoRA
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = (self.base_layer.bias
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
output = self.apply(input_, bias)
output_bias = (self.base_layer.bias
if self.base_layer.skip_bias_add else None)
if not self.base_layer.return_bias:
return output
return output, output_bias
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ReplicatedLinear
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# pylint: disable=unused-argument from typing import Optional, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.distributed.communication_op import ( from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) get_tensor_model_parallel_world_size,
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank split_tensor_along_last_dim,
from vllm.lora.layers import (ColumnParallelLinearWithLoRA, tensor_model_parallel_all_reduce)
MergedColumnParallelLinearWithLoRA, # yapf: disable
MergedQKVParallelLinearWithLoRA, from vllm.model_executor.layers.linear import RowParallelLinear
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.platforms import current_platform from vllm.platforms import current_platform
if TYPE_CHECKING: from .base_linear import BaseLinearLayerWithLoRA
pass from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
def _fully_sharded_can_replace(can_replace): class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs["lora_config"].fully_sharded_loras)
return dec
def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
"""
For `ColumnParallelLinearWithLoRA` or classes that inherit from
`ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
"""
assert (layer.n_slices == len(layer.lora_a_stacked) == len(
layer.lora_b_stacked) == len(layer.output_slices))
if layer.lora_bias_stacked is not None:
assert layer.n_slices == len(layer.lora_bias_stacked)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
# Since communication is needed, the buffer is directly initialized as a
# tensor rather than a tuple of tensor.
buffers = torch.zeros(
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink(
buffers, x, layer.lora_a_stacked, 1.0)
if not current_platform.can_update_inplace():
buffers = shrunk_buffers
buffers = tensor_model_parallel_all_gather(buffers)
lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand(
output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape) def __init__(self, base_layer: RowParallelLinear) -> None:
# now have column partitioned and packed output super().__init__(base_layer)
return output
self.tp_size = get_tensor_model_parallel_world_size()
# reset input_size
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
# these layers are based on the tensor parallelism strategy given in self.tp_rank = get_tensor_model_parallel_rank()
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, # There is only one LoRA layer.
# https://arxiv.org/abs/2311.03285. self.n_slices = 1
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
"""
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
Based on S-LoRA, slicing happens along the rank dim.
"""
# For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
# their `lora_a` and `lora_b` have different sharding patterns. After
# completing the `lora_a` GEMM , a gather operation is performed.
# Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
shard_size = self.input_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
return lora_a
class MergedColumnParallelLinearWithShardedLoRA( def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
MergedColumnParallelLinearWithLoRA): return lora_b
"""
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a( def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
self, lora_a: list[Union[torch.Tensor, None]] return bias
) -> list[Union[torch.Tensor, None]]:
#NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[0] is not None else None,
lora_a[1][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[1] is not None else None,
]
return lora_a
def apply(self, def forward(
x: torch.Tensor, self, input_: torch.Tensor
bias: Optional[torch.Tensor] = None) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
return _mcp_apply(x, bias, self) """Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (output_ + self.base_layer.bias
if self.base_layer.bias is not None else output_)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
if not self.base_layer.return_bias:
return output
return output, output_bias
@classmethod @classmethod
@_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
...@@ -169,98 +100,13 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -169,98 +100,13 @@ class MergedColumnParallelLinearWithShardedLoRA(
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator return type(source_layer) is RowParallelLinear
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
"""
Differs from QKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: list,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
"""
Differs from MergedQKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(
self, lora_a: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] +
shard_size[0]] if lora_a[0] is not None else None,
lora_a[1][:, start_idx[1]:start_idx[1] +
shard_size[1]] if lora_a[1] is not None else None,
lora_a[2][:, start_idx[2]:start_idx[2] +
shard_size[2]] if lora_a[2] is not None else None,
]
return lora_a
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
# The following layer is based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
""" """
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
import torch.nn as nn
from vllm.adapter_commons.layers import AdapterMapping
@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
# unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
# Compressed Tensor
elif hasattr(base_layer, "weight_packed"):
return base_layer.weight_packed.device
# GPTQ/AWQ
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# HQQ marlin
elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _not_fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
condition = (not kwargs["lora_config"].fully_sharded_loras
if decorate else True)
return can_replace(*args, **kwargs) and condition
return dec
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs["lora_config"].fully_sharded_loras)
return dec
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: Optional[tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
if self.base_layer.num_added_embeddings_per_partition > 0:
# We can start adding lora weights
self.embeddings_weights = self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:self.
base_layer.num_org_embeddings_per_partition +
self.base_layer.num_added_embeddings_per_partition]
self.embeddings_slice = (
self.base_layer.shard_indices.added_vocab_start_index -
self.base_layer.org_vocab_size,
self.base_layer.shard_indices.added_vocab_end_index -
self.base_layer.org_vocab_size)
self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:].fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size +
lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.embedding_dim,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
:embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] *
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2],
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
1, 0)
# NB: Don't use torch.narrow here. torch.narrow triggers some
# Dynamic Shape specialization in torch.compile
num_tokens = x.shape[0]
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
full_lora_a_embeddings = F.embedding(
x + indices_1,
self.lora_a_stacked_2d,
)
full_output = self.base_layer.forward(x +
(indices_0 * added_tokens_mask))
full_output_org = full_output
if full_output.ndim == 3:
full_output = full_output.view(
full_output.shape[0] * full_output.shape[1], -1)
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1],
-1,
)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_embedding(
full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)
if not current_platform.can_update_inplace():
full_output = lora_output
return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is VocabParallelEmbedding
@property
def weight(self):
return self.base_layer.weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment