Commit b9e12416 authored by zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
...@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase): ...@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None: def _init_executor(self) -> None:
"""Initialize the worker and load the model. """Initialize the worker and load the model.
If speculative decoding is enabled, we instead create the speculative
worker.
""" """
if self.speculative_config is None: assert self.parallel_config.world_size == 1, (
self._init_non_spec_worker() "GPUExecutor only supports single GPU.")
else:
self._init_spec_worker() self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _get_worker_kwargs( def _get_worker_kwargs(
self, self,
...@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase): ...@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
) )
...@@ -52,53 +52,22 @@ class GPUExecutor(ExecutorBase): ...@@ -52,53 +52,22 @@ class GPUExecutor(ExecutorBase):
local_rank: int = 0, local_rank: int = 0,
rank: int = 0, rank: int = 0,
distributed_init_method: Optional[str] = None): distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
wrapper = WorkerWrapperBase( wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.worker", worker_module_name=worker_module_name,
worker_class_name="Worker", worker_class_name=worker_class_name,
) )
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method)) distributed_init_method))
return wrapper.worker return wrapper.worker
def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
target_worker = self._create_worker()
draft_worker_kwargs = self._get_worker_kwargs()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
)
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = spec_decode_worker
# Load model handled in spec decode worker.
self.driver_worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
...@@ -117,8 +86,8 @@ class GPUExecutor(ExecutorBase): ...@@ -117,8 +86,8 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model( def execute_model(
self, self, execute_model_req: ExecuteModelRequest
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: ) -> List[Union[SamplerOutput, PoolerOutput]]:
output = self.driver_worker.execute_model(execute_model_req) output = self.driver_worker.execute_model(execute_model_req)
return output return output
...@@ -144,7 +113,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): ...@@ -144,7 +113,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async def execute_model_async( async def execute_model_async(
self, self,
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]: ) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, ) )(execute_model_req=execute_model_req, )
return output return output
import asyncio
import os
from functools import partial
from typing import Any, List, Optional
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
logger = init_logger(__name__)
class MultiprocessingGPUExecutor(DistributedGPUExecutor):
    """Python multiprocessing-based multi-GPU executor"""

    def _init_executor(self) -> None:
        """Create the workers, initialize their devices and load the model.

        The driver worker is created in this process through
        ``_create_worker`` with its default rank arguments. Ranks
        ``1..world_size - 1`` are each wrapped in a ``ProcessWorkerWrapper``
        and watched by a ``WorkerMonitor``. When the world size is 1 there
        are no remote workers and no monitor is created
        (``self.worker_monitor`` is ``None``).
        """
        assert (
            not self.speculative_config
        ), "Speculative decoding not yet supported for MultiProcGPU backend."

        # Create the parallel GPU workers.
        world_size = self.parallel_config.tensor_parallel_size

        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
                map(str, range(world_size))))

        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

        from torch.cuda import device_count
        assert world_size <= device_count(), (
            "please set tensor_parallel_size to less than max local gpu count")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())

        if world_size == 1:
            self.workers = []
            # No remote workers, hence nothing to monitor. Define the
            # attribute anyway so shutdown()/check_health() can rely on it.
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_worker,
                        rank=rank,
                        local_rank=rank,
                        distributed_init_method=distributed_init_method,
                    )) for rank in range(1, world_size)
            ]

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self.driver_worker = self._create_worker(
            distributed_init_method=distributed_init_method)
        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def shutdown(self):
        """Close the worker monitor, if one was created."""
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_model(
            execute_model_req=execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to invoke.
            async_run_remote_workers_only: If True the method will be run only
                in the remote workers, not the driver worker. It will also be
                run asynchronously and return a list of futures rather than
                blocking on the results.
            max_concurrent_workers: Not supported; any truthy value raises
                ``NotImplementedError``.

        Returns:
            A list whose first element is the driver worker's output followed
            by each remote worker's output, or the list of remote futures
            when ``async_run_remote_workers_only`` is True.
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the workers first.
        worker_outputs = [
            worker.execute_method(method, *args, **kwargs)
            for worker in self.workers
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        # With a single GPU no WorkerMonitor exists; there are no remote
        # processes whose liveness could be checked, so treat it as healthy
        # instead of failing with AttributeError.
        worker_monitor = getattr(self, "worker_monitor", None)
        if worker_monitor is not None and not worker_monitor.is_alive():
            raise RuntimeError("Worker processes are not running")

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
                                      DistributedGPUExecutorAsync):
    """Async flavour of the multiprocessing-based multi-GPU executor."""

    def __init__(self, *args, **kwargs):
        """Build the executor, then wrap the driver's ``execute_model``
        with ``make_async`` so it can be awaited."""
        super().__init__(*args, **kwargs)
        driver_execute_model = self.driver_worker.execute_model
        self.driver_exec_model = make_async(driver_execute_model)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Await ``execute_model`` on the driver worker and return its
        output."""
        driver_output = await self.driver_exec_model(execute_model_req)
        return driver_output

    async def _start_worker_execution_loop(self):
        """Ask every remote worker to run ``start_worker_execution_loop``
        and wait for all of them to finish."""
        pending = []
        for remote_worker in self.workers:
            pending.append(
                remote_worker.execute_method_async(
                    "start_worker_execution_loop"))
        return await asyncio.gather(*pending)
...@@ -28,10 +28,7 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG ...@@ -28,10 +28,7 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor): class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert (not self.speculative_config assert self.parallel_config.distributed_executor_backend == "ray"
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection. # Disable Ray usage stats collection.
...@@ -45,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -45,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
self.forward_dag = None self.forward_dag = None
if USE_RAY_COMPILED_DAG: if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag() self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True
def _configure_ray_workers_use_nsight(self, def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]: ray_remote_kwargs) -> Dict[str, Any]:
...@@ -90,14 +89,22 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -90,14 +89,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_capture_child_tasks=True, placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id, placement_group_bundle_index=bundle_id,
) )
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
worker = ray.remote( worker = ray.remote(
num_cpus=0, num_cpus=0,
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote( )(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker", worker_module_name=worker_module_name,
worker_class_name="Worker", worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
...@@ -107,8 +114,8 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -107,8 +114,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process. # as the resource holder for the driver process.
self.driver_dummy_worker = worker self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper( self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker", worker_module_name=worker_module_name,
worker_class_name="Worker", worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
else: else:
...@@ -166,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -166,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config. max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers) max_parallel_loading_workers)
def execute_model( def _driver_execute_model(
self, self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: execute_model_req: Optional[ExecuteModelRequest] = None
all_outputs = self._run_workers( ) -> List[SamplerOutput]:
"execute_model", """Run execute_model in the driver worker.
driver_kwargs={"execute_model_req": execute_model_req},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
# Only the driver worker returns the sampling results. Passing None will cause the driver to stop the model execution
return all_outputs[0] loop running in each of the remote workers.
"""
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[Tuple[Any, ...]] = None, async_run_remote_workers_only: bool = False,
driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[Tuple[Any, ...]]] = None, all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
...@@ -193,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -193,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
"""Runs the given method on all workers. Can be used in the following """Runs the given method on all workers. Can be used in the following
ways: ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
- args/kwargs: All workers share the same args/kwargs - args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified - all_args/all_kwargs: args/kwargs for each worker are specified
individually individually
""" """
...@@ -204,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -204,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError( raise NotImplementedError(
"max_concurrent_workers is not supported yet.") "max_concurrent_workers is not supported yet.")
if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
count = len(self.workers) count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \ all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None) else islice(all_args, 1, None)
...@@ -220,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -220,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# input. TODO(sang): Fix it. # input. TODO(sang): Fix it.
assert self.forward_dag is not None assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1) output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else: else:
# Start the ray workers first. # Start the ray workers first.
ray_worker_outputs = [ ray_worker_outputs = [
...@@ -229,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -229,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
) in zip(self.workers, all_worker_args, all_worker_kwargs) ) in zip(self.workers, all_worker_args, all_worker_kwargs)
] ]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers. # Start the driver worker after all the ray workers.
if not use_dummy_driver: if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method( driver_worker_output = self.driver_worker.execute_method(
...@@ -255,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -255,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
return [driver_worker_output] + ray_worker_outputs return [driver_worker_output] + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self): def _compiled_ray_dag(self):
import pkg_resources import pkg_resources
required_version = "2.9" required_version = "2.9"
...@@ -264,7 +281,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -264,7 +281,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
f"required, but found {current_version}") f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray assert self.parallel_config.distributed_executor_backend == "ray"
# Right now, compiled DAG requires at least 1 arg. We send # Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon. # a dummy value for now. It will be fixed soon.
...@@ -298,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -298,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method) self.driver_exec_method = make_async(self.driver_worker.execute_method)
async def _run_workers_async( async def _driver_execute_model_async(
self, self,
method: str, execute_model_req: Optional[ExecuteModelRequest] = None
*args, ) -> List[SamplerOutput]:
driver_args: Optional[Tuple[Any, ...]] = None, return await self.driver_exec_method("execute_model",
driver_kwargs: Optional[Dict[str, Any]] = None, execute_model_req)
**kwargs,
) -> Any: async def _start_worker_execution_loop(self):
"""Runs the given method on all workers.""" coros = [
coros = [] worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
if driver_args is None: ]
driver_args = args return await asyncio.gather(*coros)
if driver_kwargs is None:
driver_kwargs = kwargs
coros.append(
self.driver_executor(method, *driver_args, **driver_kwargs))
# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
return all_outputs
...@@ -44,7 +44,7 @@ try: ...@@ -44,7 +44,7 @@ try:
except ImportError as e: except ImportError as e:
logger.warning( logger.warning(
"Failed to import Ray with %r. For distributed inference, " "Failed to import Ray with %r. For multi-node inference, "
"please install Ray with `pip install ray`.", e) "please install Ray with `pip install ray`.", e)
ray = None # type: ignore ray = None # type: ignore
RayWorkerWrapper = None # type: ignore RayWorkerWrapper = None # type: ignore
...@@ -67,7 +67,7 @@ def initialize_ray_cluster( ...@@ -67,7 +67,7 @@ def initialize_ray_cluster(
""" """
if ray is None: if ray is None:
raise ImportError( raise ImportError(
"Ray is not installed. Please install Ray to use distributed " "Ray is not installed. Please install Ray to use multi-node "
"serving.") "serving.")
# Connect to a ray cluster. # Connect to a ray cluster.
......
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.sequence import MultiModalData
class ParsedText(TypedDict):
    """A single text prompt as returned by :func:`parse_and_batch_prompt`."""
    # The prompt text.
    content: str
    # Always ``False``: marks ``content`` as text rather than token IDs.
    is_tokens: Literal[False]
class ParsedTokens(TypedDict):
    """A single tokenized prompt as returned by
    :func:`parse_and_batch_prompt`."""
    # The prompt as a list of token IDs.
    content: List[int]
    # Always ``True``: marks ``content`` as token IDs rather than text.
    is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Turn a prompt argument into a list of parsed prompts.

    Four input shapes are recognised:

    1. a string -> one ``ParsedText``
    2. a list of strings -> one ``ParsedText`` per string
    3. a list of ints -> one ``ParsedTokens`` holding the whole list
    4. a list of lists of ints -> one ``ParsedTokens`` per inner list

    For list inputs the shape is decided by looking at the first element
    only (and, for nested lists, the first item of that first element).

    Raises:
        ValueError: If the list (or its first inner list) is empty, or if
            the input matches none of the shapes above.
    """
    if isinstance(prompt, str):
        # Shape 1: a lone string becomes a batch of one.
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        first = prompt[0]

        if isinstance(first, str):
            # Shape 2: every entry is a separate text prompt.
            texts = cast(List[str], prompt)
            return [ParsedText(content=text, is_tokens=False) for text in texts]

        if isinstance(first, int):
            # Shape 3: the entire list is one tokenized prompt.
            token_ids = cast(List[int], prompt)
            return [ParsedTokens(content=token_ids, is_tokens=True)]

        if isinstance(first, list):
            if not first:
                raise ValueError("please provide at least one prompt")

            if isinstance(first[0], int):
                # Shape 4: every inner list is a separate tokenized prompt.
                token_batches = cast(List[List[int]], prompt)
                return [
                    ParsedTokens(content=token_ids, is_tokens=True)
                    for token_ids in token_batches
                ]

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
    """Schema for a text prompt.

    ``prompt`` is required; ``multi_modal_data`` is declared
    ``NotRequired`` and may be left out.
    """

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired["MultiModalData"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """
class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt.

    ``prompt_token_ids`` is required; ``multi_modal_data`` is declared
    ``NotRequired`` and may be left out.
    """

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

    multi_modal_data: NotRequired["MultiModalData"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """
class TextTokensPrompt(TypedDict):
    """It is assumed that :attr:`prompt` is consistent with
    :attr:`prompt_token_ids`. This is currently used in
    :class:`AsyncLLMEngine` for logging both the text and token IDs."""

    prompt: str
    """The prompt text."""

    # NOTE(review): the docstring below mentions ``None``, but this key is
    # annotated as a required ``List[int]`` — confirm which one is intended.
    prompt_token_ids: List[int]
    """The token IDs of the prompt. If None, we use the
    tokenizer to convert the prompts to token IDs."""

    multi_modal_data: NotRequired["MultiModalData"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """
# Type alias: a plain string or one of the two single-form prompt schemas.
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
# Type alias: the same members as PromptStrictInputs plus TextTokensPrompt.
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
    """Schema pairing prompt token IDs with optional prompt text and
    optional multi-modal data.

    Only ``prompt_token_ids`` is required; ``prompt`` and
    ``multi_modal_data`` may be omitted or set to ``None``.
    """
    # Token IDs of the prompt (required).
    prompt_token_ids: List[int]
    # Prompt text, if available.
    prompt: NotRequired[Optional[str]]
    # Multi-modal data accompanying the prompt, if any.
    multi_modal_data: NotRequired[Optional["MultiModalData"]]
...@@ -14,6 +14,7 @@ import vllm.envs as envs ...@@ -14,6 +14,7 @@ import vllm.envs as envs
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S" _DATE_FORMAT = "%m-%d %H:%M:%S"
...@@ -30,7 +31,7 @@ DEFAULT_LOGGING_CONFIG = { ...@@ -30,7 +31,7 @@ DEFAULT_LOGGING_CONFIG = {
"vllm": { "vllm": {
"class": "logging.StreamHandler", "class": "logging.StreamHandler",
"formatter": "vllm", "formatter": "vllm",
"level": "INFO", "level": VLLM_LOGGING_LEVEL,
"stream": "ext://sys.stdout", "stream": "ext://sys.stdout",
}, },
}, },
......
# pylint: disable=unused-argument # pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
lora_a = lora_a[:, start_idx:start_idx + shard_size] lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a return lora_a
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.base_layer, x, bias)
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output, out_orig_shape = output.view(-1,
...@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
) )
def _mcp_apply_weights(x, bias, layer): def _mcp_apply(x, bias, layer):
""" """
MergedColumnParallelLinearWithShardedLoRA and MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same QKVParallelLinearWithShardedLora share the same
...@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer): ...@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
""" """
# expecting 2 for column parallel and 3 for qkv # expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked) n = len(layer.lora_a_stacked)
output = layer.base_layer.linear_method.apply_weights( output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
...@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
""" """
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_a[0] is None or lora_a[1] is None:
return lora_a
output_shard_size = self.lora_a_stacked[0].shape[2] output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size output_start_idx = self.tp_rank * output_shard_size
lora_a = [ lora_a = [
lora_a[i][:, output_start_idx:output_start_idx + output_shard_size] lora_a[0][:,
for i in range(2) output_start_idx:output_start_idx + output_shard_size],
lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
] ]
return lora_a return lora_a
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self) return _mcp_apply(x, bias, self)
@classmethod @classmethod
@_fully_sharded_can_replace @_fully_sharded_can_replace
...@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): ...@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
""" """
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
return lora_a
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [ lora_a = [
lora_a[i][:, start_idx[i]:start_idx[i] + lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
shard_size[i]] if lora_a[i] is not None else None lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
for i in range(3) lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
] ]
return lora_a return lora_a
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self) return _mcp_apply(x, bias, self)
@classmethod @classmethod
@_fully_sharded_can_replace @_fully_sharded_can_replace
...@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b = lora_b[:, start_idx:end_idx] lora_b = lora_b[:, start_idx:end_idx]
return lora_b return lora_b
def apply_weights(self, x: torch.Tensor) -> torch.Tensor: def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x)
self.base_layer, x)
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output, out_orig_shape = output.view(-1,
......
# pylint: disable=unused-argument # pylint: disable=unused-argument
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -22,6 +22,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -22,6 +22,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -145,11 +147,15 @@ class LoRAMapping: ...@@ -145,11 +147,15 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(
self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
"""Slice lora a if splitting for tensor parallelism.""" """Slice lora a if splitting for tensor parallelism."""
... ...
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_lora_b(
self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
"""Slice lora b if splitting with tensor parallelism.""" """Slice lora b if splitting with tensor parallelism."""
... ...
...@@ -181,6 +187,7 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -181,6 +187,7 @@ class BaseLayerWithLoRA(nn.Module):
sampler_indices: torch.Tensor, sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor, sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor, embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int], indices_len: List[int],
): ):
"""Sets the mapping indices.""" """Sets the mapping indices."""
...@@ -302,6 +309,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -302,6 +309,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor, sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor, sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor, embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int], indices_len: List[int],
): ):
self.indices = base_indices self.indices = base_indices
...@@ -427,6 +435,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -427,6 +435,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor, sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor, sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor, embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int], indices_len: List[int],
): ):
self.indices = base_indices self.indices = base_indices
...@@ -539,10 +548,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -539,10 +548,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0 self.lora_b_stacked[1][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_b[0] is None or lora_b[1] is None:
return lora_b
shard_size = self.output_dim shard_size = self.output_dim
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
...@@ -767,10 +782,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -767,10 +782,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[2][index] = 0 self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
lora_b_q, lora_b_k, lora_b_v = None, None, None
if lora_b[0] is not None: if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size * lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size *
...@@ -936,6 +956,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -936,6 +956,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor, sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor, sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor, embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int], indices_len: List[int],
): ):
self.indices = base_indices self.indices = base_indices
...@@ -992,7 +1013,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -992,7 +1013,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@property @property
def weight(self): def weight(self):
return self.base_layer.weight if hasattr( return self.base_layer.weight if hasattr(
self.base_layer, "weight") else self.base_layer.qweight self.base_layer, "weight") else self.base_layer.qweight
...@@ -1113,6 +1133,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1113,6 +1133,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor, sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor, sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor, embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int], indices_len: List[int],
): ):
self.indices = sampler_indices self.indices = sampler_indices
...@@ -1179,3 +1200,101 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1179,3 +1200,101 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor. # Special handling for the LogitsProcessor.
return False return False
class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Rotary-embedding wrapper used to serve long-context LoRA adapters.

    ``create_lora_weights`` replaces the wrapped layer with one
    ``LinearScalingRotaryEmbedding`` built from every scaling factor that
    is configured, and ``forward`` hands the ``long_lora_indices`` tensor
    received through ``set_mapping`` to that layer as ``offsets``.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Only declared here; both are assigned by set_mapping().
        self.long_lora_indices: torch.Tensor
        self.indices_len: List[int]

    @property
    def scaling_factors(self):
        """Scaling factors reported by the wrapped layer."""
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        """Rotary dimension reported by the wrapped layer."""
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the wrapped layer so it covers all scaling factors.

        The new layer is created from the old layer's own parameters plus
        the sorted, de-duplicated union of the old layer's scaling factor
        (1.0 when the old layer is not a ``LinearScalingRotaryEmbedding``)
        and ``lora_config.long_lora_scaling_factors``.

        Args:
            max_loras: Unused by this layer.
            lora_config: Supplies ``long_lora_scaling_factors``.
            model_config: Unused by this layer.
        """
        configured = lora_config.long_lora_scaling_factors
        extra_factors = list(configured) if configured else []

        previous = self.base_layer
        # NOTE(review): this reads ``scaling_factor`` (singular) while the
        # ``scaling_factors`` property above reads the plural attribute --
        # confirm LinearScalingRotaryEmbedding exposes both.
        if isinstance(previous, LinearScalingRotaryEmbedding):
            base_factor = previous.scaling_factor
        else:
            base_factor = 1.0

        merged_factors = sorted({base_factor, *extra_factors})
        self.base_layer = LinearScalingRotaryEmbedding(
            previous.head_size,
            previous.rotary_dim,
            previous.max_position_embeddings,
            previous.base,
            previous.is_neox_style,
            merged_factors,
            previous.dtype,
        )

    def reset_lora(self, index: int):
        """Do nothing; this layer keeps no per-slot LoRA weights."""

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Do nothing; this layer keeps no per-slot LoRA weights."""

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Keep references to the two mapping objects this layer uses.

        Only ``long_lora_indices`` and ``indices_len`` are stored; the
        other arguments are ignored.
        """
        self.indices_len = indices_len
        self.long_lora_indices = long_lora_indices

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the wrapped layer with the currently mapped offsets.

        Entry 4 of ``indices_len`` is used as the number of leading
        elements of ``long_lora_indices`` that are passed as ``offsets``.
        """
        valid_len = self.indices_len[4]
        active_offsets = self.long_lora_indices[:valid_len]
        return self.base_layer(positions,
                               query,
                               key,
                               offsets=active_offsets)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        """The wrapped layer's scaling-factor -> offset table."""
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # Exact type comparison: subclasses of these two classes do not
        # match.
        layer_type = type(source_layer)
        if layer_type is LinearScalingRotaryEmbedding:
            return True
        return layer_type is RotaryEmbedding

    def extra_repr(self) -> str:
        """Delegate the module description to the wrapped layer."""
        return self.base_layer.extra_repr()
...@@ -3,7 +3,8 @@ import json ...@@ -3,7 +3,8 @@ import json
import math import math
import os import os
import re import re
from typing import Callable, Dict, List, Optional, Tuple, Type from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -11,7 +12,9 @@ from torch import nn ...@@ -11,7 +12,9 @@ from torch import nn
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
...@@ -22,10 +25,27 @@ logger = init_logger(__name__) ...@@ -22,10 +25,27 @@ logger = init_logger(__name__)
_GLOBAL_LORA_ID = 0 _GLOBAL_LORA_ID = 0
@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context.

    Attributes:
        scaling_factors: Scaling factors used to support long-context
            fine-tuned lora models.
        rot_dim: Dimension the rotary embedding is applied to.
        offsets_by_lora_id: Offset into the sin_cos_cache for each loaded
            lora id. Starts empty and is modified dynamically.
    """
    scaling_factors: List[float]
    rot_dim: int
    # A fresh dict per instance; never shared between contexts.
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
def convert_mapping( def convert_mapping(
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], mapping: LoRAMapping,
max_loras: int, vocab_size: int, extra_vocab_size: int lora_index_to_id: List[Optional[int]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional[LongContextLoRAContext] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors. """Converts LoRAMapping to index tensors.
Args: Args:
...@@ -34,6 +54,7 @@ def convert_mapping( ...@@ -34,6 +54,7 @@ def convert_mapping(
max_loras: Maximum number of LoRAs. max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size. vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have. extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns: Returns:
A tuple of tensors: A tuple of tensors:
...@@ -51,11 +72,23 @@ def convert_mapping( ...@@ -51,11 +72,23 @@ def convert_mapping(
requests to embedding indices. First row is for embeddings requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a added by the LoRAs, second row is for the LoRA.lora_a
embeddings. embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. indices_len: List of lengths of the above tensors.
Used to index into each tensor. It contains length for
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices). If long_lora doesn't
exist, it only contains first 4 entries.
""" """
index_mapping_indices: List[int] = list(mapping.index_mapping).copy() index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy() embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
dtype=torch.long)
prompt_mapping: List[int] = [ prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1 lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping for x in mapping.prompt_mapping
...@@ -66,13 +99,20 @@ def convert_mapping( ...@@ -66,13 +99,20 @@ def convert_mapping(
lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1) if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx lora_indices[i] = lora_idx
if long_lora_context:
indices = torch.tensor( assert long_lora_offsets is not None
[index_mapping_indices, lora_indices, embedding_indices], lora_offset: int = long_lora_context.offsets_by_lora_id.get(
dtype=torch.long, index_mapping_indices[i], 0)
device="cuda") long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices, lora_indices, embedding_indices
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping, prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda", device="cuda",
dtype=torch.long) dtype=torch.long)
...@@ -89,13 +129,21 @@ def convert_mapping( ...@@ -89,13 +129,21 @@ def convert_mapping(
torch.arange( torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded))) (sampler_indices_padded * len(sampler_indices_padded)))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [ indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1], base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
] ]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
return (base_indices, sampler_indices, sampler_indices_padded, return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len) embeddings_indices, long_lora_indices, indices_len)
def get_lora_id(): def get_lora_id():
...@@ -112,13 +160,35 @@ class LoRAModel: ...@@ -112,13 +160,35 @@ class LoRAModel:
lora_model_id: int, lora_model_id: int,
rank: int, rank: int,
loras: Dict[str, LoRALayerWeights], loras: Dict[str, LoRALayerWeights],
scaling_factor: Optional[float] = None,
) -> None: ) -> None:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
scaling_factor: Scaling factor to support long context lora model.
None if the lora is not tuned for long context support.
"""
self.id = lora_model_id self.id = lora_model_id
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
self.scaling_factor = scaling_factor
assert (lora_model_id > assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}" 0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras self.loras: Dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel":
    """Return a copy of this model registered under a different id.

    The copy receives its own ``loras`` dict (a shallow copy), so the
    weight objects stored in it are shared with the original.

    Args:
        lora_model_id: The integer id to give the copy.

    Returns:
        A new instance of the same class with the same rank, the same
        long-context scaling factor and a shallow copy of ``loras``.
    """
    return self.__class__(
        lora_model_id,
        rank=self.rank,
        loras=self.loras.copy(),
        # Carry the long-context scaling factor over; otherwise the copy
        # silently falls back to the constructor default (None).
        scaling_factor=self.scaling_factor,
    )
@property @property
def extra_vocab_size(self) -> int: def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size return max(lora.extra_vocab_size
...@@ -140,6 +210,7 @@ class LoRAModel: ...@@ -140,6 +210,7 @@ class LoRAModel:
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None, embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None, target_embedding_padding: Optional[int] = None,
scaling_factor: Optional[float] = None,
embedding_modules: Optional[Dict[str, str]] = None, embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None, embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel": ) -> "LoRAModel":
...@@ -189,13 +260,15 @@ class LoRAModel: ...@@ -189,13 +260,15 @@ class LoRAModel:
for lora in loras.values(): for lora in loras.values():
lora.optimize() lora.optimize()
return cls(lora_model_id, rank, loras) return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
@classmethod @classmethod
def from_local_checkpoint( def from_local_checkpoint(
cls, cls,
lora_dir: str, lora_dir: str,
expected_lora_modules: List[str], expected_lora_modules: List[str],
*,
max_position_embeddings: Optional[int] = None,
lora_model_id: Optional[int] = None, lora_model_id: Optional[int] = None,
device: str = "cuda", device: str = "cuda",
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
...@@ -203,7 +276,23 @@ class LoRAModel: ...@@ -203,7 +276,23 @@ class LoRAModel:
embedding_modules: Optional[Dict[str, str]] = None, embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None, embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel": ) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.""" """Create a LoRAModel from a local checkpoint.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
max_position_embeddings: Max position embedding length. Used to
derive the scaling factor from the lora's context length. If
None, the lora model's context length is not scaled.
lora_model_id: Lora model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
Returns:
Loaded LoRA Model.
"""
lora_config_path = os.path.join(lora_dir, "adapter_config.json") lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
...@@ -221,7 +310,9 @@ class LoRAModel: ...@@ -221,7 +310,9 @@ class LoRAModel:
if part_name not in expected_lora_modules: if part_name not in expected_lora_modules:
unexpected_modules.append(module) unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules # loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules: if unexpected_modules:
print(unexpected_modules, "modules")
raise ValueError( raise ValueError(
f"While loading {lora_dir}, expected" f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}" f" target modules in {expected_lora_modules}"
...@@ -243,6 +334,14 @@ class LoRAModel: ...@@ -243,6 +334,14 @@ class LoRAModel:
rank = config["r"] rank = config["r"]
lora_alpha = config["lora_alpha"] lora_alpha = config["lora_alpha"]
context_length = config.get("context_length", None)
scaling_factor = None
if context_length:
if max_position_embeddings is None:
max_position_embeddings = context_length
scaling_factor = float(
math.ceil(context_length / max_position_embeddings))
return cls.from_lora_tensors( return cls.from_lora_tensors(
lora_model_id=get_lora_id() lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id, if lora_model_id is None else lora_model_id,
...@@ -253,6 +352,7 @@ class LoRAModel: ...@@ -253,6 +352,7 @@ class LoRAModel:
dtype=dtype, dtype=dtype,
embeddings=embeddings, embeddings=embeddings,
target_embedding_padding=target_embedding_padding, target_embedding_padding=target_embedding_padding,
scaling_factor=scaling_factor,
embedding_modules=embedding_modules, embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules, embedding_padding_modules=embedding_padding_modules,
) )
...@@ -286,6 +386,7 @@ class LoRAModelManager: ...@@ -286,6 +386,7 @@ class LoRAModelManager:
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.base_indices = torch.empty(self.max_num_batched_tokens, self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long, dtype=torch.long,
device="cuda") device="cuda")
...@@ -299,6 +400,12 @@ class LoRAModelManager: ...@@ -299,6 +400,12 @@ class LoRAModelManager:
self.max_num_batched_tokens, self.max_num_batched_tokens,
dtype=torch.long, dtype=torch.long,
device="cuda") device="cuda")
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
# Scaling factor -> offset into the sin_cos_cache for that factor.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
# 4 is the number of indicies tensors defined above # 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices # embeddings_indices
...@@ -308,6 +415,10 @@ class LoRAModelManager: ...@@ -308,6 +415,10 @@ class LoRAModelManager:
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy( self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules) self.model.supported_lora_modules)
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy( self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping) self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {} self.packed_modules: Dict[str, List[str]] = {}
...@@ -373,12 +484,32 @@ class LoRAModelManager: ...@@ -373,12 +484,32 @@ class LoRAModelManager:
return True return True
return False return False
def _set_long_lora_context(self, lora: LoRAModel):
    """Register the sin_cos_cache offset of a long-context lora.

    Does nothing when no long-lora context exists or when the lora has no
    scaling factor. A falsy offset (0 or missing) is not stored.

    Args:
        lora: The lora model being added.

    Raises:
        ValueError: If the lora's scaling factor is not a key of
            ``self.scaling_factor_to_offset``.
    """
    if self.long_lora_context is None or lora.scaling_factor is None:
        return

    offset_table = self.scaling_factor_to_offset
    if lora.scaling_factor not in offset_table:
        raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                         " has not been initialized.")

    offset = offset_table.get(lora.scaling_factor)
    if not offset:
        return
    self.long_lora_context.offsets_by_lora_id[lora.id] = offset
def _add_lora(self, lora: LoRAModel): def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora) self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora self._registered_loras[lora.id] = lora
self._set_long_lora_context(lora)
def add_lora(self, lora: LoRAModel) -> bool: def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache.""" """Add a LoRAModel to the manager CPU cache."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras: if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity: if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.") raise RuntimeError("No free LoRA slots.")
...@@ -390,15 +521,18 @@ class LoRAModelManager: ...@@ -390,15 +521,18 @@ class LoRAModelManager:
"""Remove a LoRAModel from the manager CPU cache.""" """Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora? # TODO: should we check active lora?
self.deactivate_lora(lora_id) self.deactivate_lora(lora_id)
if self.long_lora_context:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None)) return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized # TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None: def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded, (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, embeddings_indices, long_lora_offsets_tensor,
indices_len) = convert_mapping(mapping, self.lora_index_to_id, indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size, self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size) self.lora_config.lora_extra_vocab_size,
self.long_lora_context)
self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
...@@ -406,6 +540,11 @@ class LoRAModelManager: ...@@ -406,6 +540,11 @@ class LoRAModelManager:
self.embeddings_indices[:embeddings_indices. self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_( shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices) embeddings_indices)
if long_lora_offsets_tensor is not None:
self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self.long_lora_indices.zero_()
# Maintain the reference # Maintain the reference
self.indices_len[:] = indices_len self.indices_len[:] = indices_len
...@@ -428,7 +567,8 @@ class LoRAModelManager: ...@@ -428,7 +567,8 @@ class LoRAModelManager:
self._active_loras.clear() self._active_loras.clear()
def _create_lora_modules(self): def _create_lora_modules(self):
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules(
remove_duplicate=False):
if not self._match_target_modules(module_name): if not self._match_target_modules(module_name):
continue continue
parts = module_name.split(".")[-1] parts = module_name.split(".")[-1]
...@@ -437,6 +577,13 @@ class LoRAModelManager: ...@@ -437,6 +577,13 @@ class LoRAModelManager:
self.model, module_name, self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config, from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config)) packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
self.long_lora_context = LongContextLoRAContext(
new_module.scaling_factors, new_module.rotary_dim)
self.scaling_factor_to_offset = \
new_module.scaling_factor_to_offset
# (yard1): TODO make this more robust # (yard1): TODO make this more robust
if "lm_head" in module_name: if "lm_head" in module_name:
logits_processor_module = self.model.get_submodule( logits_processor_module = self.model.get_submodule(
...@@ -451,7 +598,8 @@ class LoRAModelManager: ...@@ -451,7 +598,8 @@ class LoRAModelManager:
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices, new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded, self.sampler_indices_padded,
self.embeddings_indices, self.indices_len) self.embeddings_indices,
self.long_lora_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA) assert isinstance(module, BaseLayerWithLoRA)
...@@ -461,12 +609,14 @@ class LoRAModelManager: ...@@ -461,12 +609,14 @@ class LoRAModelManager:
self, self,
lora_id: int, lora_id: int,
rank: int, rank: int,
scaling_factor: Optional[float],
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup.""" """Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}) model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance( if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA): module, BaseLayerWithLoRA) or isinstance(
module, LinearScalingRotaryEmbeddingWithLora):
continue continue
parts = module_name.split(".") parts = module_name.split(".")
if module_name not in self.packed_modules: if module_name not in self.packed_modules:
...@@ -596,6 +746,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager): ...@@ -596,6 +746,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def add_lora(self, lora: LoRAModel) -> bool: def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager.""" """Add a LoRAModel to the manager."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras: if lora.id not in self._registered_loras:
self._add_lora(lora) self._add_lora(lora)
was_added = True was_added = True
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
@dataclass @dataclass
...@@ -18,6 +19,7 @@ class LoRARequest: ...@@ -18,6 +19,7 @@ class LoRARequest:
lora_name: str lora_name: str
lora_int_id: int lora_int_id: int
lora_local_path: str lora_local_path: str
long_lora_max_len: Optional[int] = None
def __post_init__(self): def __post_init__(self):
if self.lora_int_id < 1: if self.lora_int_id < 1:
......
...@@ -13,6 +13,7 @@ from vllm.lora.fully_sharded_layers import ( ...@@ -13,6 +13,7 @@ from vllm.lora.fully_sharded_layers import (
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLora,
...@@ -26,12 +27,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead ...@@ -26,12 +27,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__) logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA,
MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora, ColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA MergedQKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA,
LinearScalingRotaryEmbeddingWithLora,
} }
......
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Set, Type from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
import torch import torch
...@@ -16,15 +17,31 @@ logger = init_logger(__name__) ...@@ -16,15 +17,31 @@ logger = init_logger(__name__)
class AbstractWorkerLoRAManager(ABC): class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side.""" """Abstract class for managing LoRA models on the worker side."""
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, def __init__(self,
vocab_size: int, lora_config: LoRAConfig, max_num_seqs: int,
device: torch.device): max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
max_position_embeddings: Optional[int] = None):
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.device = device self.device = device
self.lora_config = lora_config self.lora_config = lora_config
# If False, do not cache. If None, cache is empty.
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@abstractproperty @abstractproperty
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
... ...
...@@ -80,14 +97,21 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -80,14 +97,21 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_modules: Dict[str, str], embedding_modules: Dict[str, str],
embedding_padding_modules: List[str], embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel, lora_model_cls: Type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None,
): ):
self._lora_model_cls = lora_model_cls self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager. # Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, super().__init__(
lora_config, device) max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
device,
max_position_embeddings=max_position_embeddings,
)
@property @property
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
...@@ -150,6 +174,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -150,6 +174,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora = self._lora_model_cls.from_local_checkpoint( lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path, lora_request.lora_local_path,
expected_lora_modules, expected_lora_modules,
max_position_embeddings=self.max_position_embeddings,
lora_model_id=lora_request.lora_int_id, lora_model_id=lora_request.lora_int_id,
device="cpu", device="cpu",
dtype=self.lora_config.lora_dtype, dtype=self.lora_config.lora_dtype,
...@@ -174,9 +199,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -174,9 +199,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras(): if lora_request.lora_int_id in self.list_loras():
return False return False
return self._lora_manager.add_lora( if isinstance(self._cached_dummy_lora, LoRAModel):
self._lora_manager.create_dummy_lora(lora_request.lora_int_id, dummy_lora = self._cached_dummy_lora.clone(
rank, self.embedding_modules)) lora_request.lora_int_id)
else:
dummy_lora = self._lora_manager.create_dummy_lora(
lora_request.lora_int_id, rank, 1, self.embedding_modules)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._lora_manager.add_lora(dummy_lora)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras(): if lora_request.lora_int_id in self.list_loras():
......
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, get_config_file_name) fused_experts, fused_moe, fused_topk, get_config_file_name)
__all__ = [ __all__ = [
"fused_moe", "fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name", "get_config_file_name",
] ]
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8
},
"48": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_stages": 0
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_stages": 0
}
}
...@@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int, ...@@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int,
return None return None
def fused_moe( def fused_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
inplace: bool = False, ):
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch") "Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape M, _ = hidden_states.shape
E, N, _ = w1.shape
if is_hip(): if is_hip():
# The MoE kernels are not yet supported on ROCm. # The MoE kernels are not yet supported on ROCm.
...@@ -393,6 +349,33 @@ def fused_moe( ...@@ -393,6 +349,33 @@ def fused_moe(
del token_expert_indicies # Not used. Will be used in the future. del token_expert_indicies # Not used. Will be used in the future.
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape
E, N, _ = w1.shape
if override_config: if override_config:
config = override_config config = override_config
...@@ -477,3 +460,63 @@ def fused_moe( ...@@ -477,3 +460,63 @@ def fused_moe(
out=hidden_states) out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1) dim=1)
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
override_config=override_config,
use_fp8=use_fp8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
...@@ -59,7 +59,6 @@ class LinearMethodBase(QuantizeMethodBase): ...@@ -59,7 +59,6 @@ class LinearMethodBase(QuantizeMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor. """Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer.""" Expects create_weights to have been called before on the layer."""
raise NotImplementedError raise NotImplementedError
...@@ -81,8 +80,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -81,8 +80,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes: List[int], input_size: int, output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes) weight = Parameter(torch.empty(sum(output_partition_sizes),
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition, input_size_per_partition,
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
...@@ -161,15 +159,13 @@ class ReplicatedLinear(LinearBase): ...@@ -161,15 +159,13 @@ class ReplicatedLinear(LinearBase):
quant_config: Quantization configure. quant_config: Quantization configure.
""" """
def __init__( def __init__(self,
self, input_size: int,
input_size: int, output_size: int,
output_size: int, bias: bool = True,
bias: bool = True, skip_bias_add: bool = False,
skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None,
params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config)
...@@ -222,17 +218,15 @@ class ColumnParallelLinear(LinearBase): ...@@ -222,17 +218,15 @@ class ColumnParallelLinear(LinearBase):
the list would be size 3. the list would be size 3.
""" """
def __init__( def __init__(self,
self, input_size: int,
input_size: int, output_size: int,
output_size: int, bias: bool = True,
bias: bool = True, gather_output: bool = False,
gather_output: bool = False, skip_bias_add: bool = False,
skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None,
params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None,
quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None):
output_sizes: Optional[List[int]] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config)
...@@ -240,18 +234,26 @@ class ColumnParallelLinear(LinearBase): ...@@ -240,18 +234,26 @@ class ColumnParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size) assert self.quant_method is not None
self.output_size_per_partition = divide(self.output_size, tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, tp_size)
for output_size in self.output_sizes
]
if output_sizes is None: if output_sizes is None:
output_sizes = [output_size] output_sizes = [output_size]
# All the linear layer supports quant method. self.quant_method.create_weights(
assert self.quant_method is not None layer=self,
self.quant_method.create_weights(self, input_size_per_partition=self.input_size,
self.input_size, output_partition_sizes=self.output_partition_sizes,
[x // tp_size for x in output_sizes], input_size=self.input_size,
self.input_size, output_size=self.output_size,
self.output_size, params_dtype=self.params_dtype,
self.params_dtype, weight_loader=self.weight_loader)
weight_loader=self.weight_loader)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
...@@ -333,24 +335,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -333,24 +335,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure. quant_config: Quantization configure.
""" """
def __init__( def __init__(self,
self, input_size: int,
input_size: int, output_sizes: List[int],
output_sizes: List[int], bias: bool = True,
bias: bool = True, gather_output: bool = False,
gather_output: bool = False, skip_bias_add: bool = False,
skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None,
params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output, super().__init__(input_size=input_size,
skip_bias_add, params_dtype, quant_config, output_size=sum(output_sizes),
self.output_sizes) bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -360,6 +365,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -360,6 +365,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks. # Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False) is_metadata = getattr(param, "is_metadata", False)
param_shard_splitter = getattr(param, "shard_splitter", None)
if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales. # Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None) None)
...@@ -424,6 +449,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -424,6 +449,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = loaded_weight.shape[0] shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size) param_data = param_data.narrow(0, shard_offset, shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)
# Special case for Fp8 scales. # Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None: elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer( param_data, loaded_weight = fp8_scales_shard_indexer(
...@@ -436,7 +468,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -436,7 +468,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is " "MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.") "the same for all partitions.")
if fp8_scales_shard_indexer is None:
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn: if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight) param_data_.copy_(loaded_weight)
...@@ -448,6 +487,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -448,6 +487,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation. """Linear layers for the attention's QKV transformation.
...@@ -472,17 +512,15 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -472,17 +512,15 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure. quant_config: Quantization configure.
""" """
def __init__( def __init__(self,
self, hidden_size: int,
hidden_size: int, head_size: int,
head_size: int, total_num_heads: int,
total_num_heads: int, total_num_kv_heads: Optional[int] = None,
total_num_kv_heads: Optional[int] = None, bias: bool = True,
bias: bool = True, skip_bias_add: bool = False,
skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None,
params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
self.total_num_heads = total_num_heads self.total_num_heads = total_num_heads
...@@ -502,14 +540,18 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -502,14 +540,18 @@ class QKVParallelLinear(ColumnParallelLinear):
input_size = self.hidden_size input_size = self.hidden_size
output_size = (self.num_heads + output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size 2 * self.num_kv_heads) * tp_size * self.head_size
output_sizes = [ self.output_sizes = [
self.num_heads * tp_size * self.head_size, self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * tp_size * self.head_size, self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * tp_size * self.head_size self.num_kv_heads * self.head_size * tp_size, # v_proj
] ]
super().__init__(input_size=input_size,
super().__init__(input_size, output_size, bias, False, skip_bias_add, output_size=output_size,
params_dtype, quant_config, output_sizes) bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
...@@ -520,6 +562,26 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -520,6 +562,26 @@ class QKVParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks. # Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False) is_metadata = getattr(param, "is_metadata", False)
param_shard_splitter = getattr(param, "shard_splitter", None)
if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales. # Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None) None)
...@@ -558,6 +620,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -558,6 +620,8 @@ class QKVParallelLinear(ColumnParallelLinear):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"] assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None: if output_dim is not None:
if loaded_shard_id == "q": if loaded_shard_id == "q":
shard_offset = 0 shard_offset = 0
...@@ -601,6 +665,12 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -601,6 +665,12 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_index = ["q", "k", "v"].index(loaded_shard_id) shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size, param_data = param_data.narrow(0, shard_index * shard_size,
shard_size) shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)
# Special case for Fp8 scales. # Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None: elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer( param_data, loaded_weight = fp8_scales_shard_indexer(
...@@ -612,6 +682,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -612,6 +682,11 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same " "QKVParallelLinear, assume the weight is the same "
"for all partitions.") "for all partitions.")
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn: if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape assert param_data_.shape == loaded_weight.shape
...@@ -650,17 +725,15 @@ class RowParallelLinear(LinearBase): ...@@ -650,17 +725,15 @@ class RowParallelLinear(LinearBase):
quant_config: Quantization configure. quant_config: Quantization configure.
""" """
def __init__( def __init__(self,
self, input_size: int,
input_size: int, output_size: int,
output_size: int, bias: bool = True,
bias: bool = True, input_is_parallel: bool = True,
input_is_parallel: bool = True, skip_bias_add: bool = False,
skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None,
params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True,
reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config)
...@@ -670,16 +743,15 @@ class RowParallelLinear(LinearBase): ...@@ -670,16 +743,15 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
# All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights(self, self.quant_method.create_weights(
self.input_size_per_partition, layer=self,
[self.output_size], input_size_per_partition=self.input_size_per_partition,
self.input_size, output_partition_sizes=[self.output_size],
self.output_size, input_size=self.input_size,
self.params_dtype, output_size=self.output_size,
weight_loader=self.weight_loader) params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results") "results can lead to incorrect results")
...@@ -708,12 +780,16 @@ class RowParallelLinear(LinearBase): ...@@ -708,12 +780,16 @@ class RowParallelLinear(LinearBase):
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx, loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size) shard_size)
# Special case for Fp8 scales. # Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None: elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data, param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight, loaded_weight,
shard_id=0) shard_id=0)
if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn: if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
......
"""A layer that compute logits from hidden_stats.""" """A layer that compute logits from hidden_stats."""
import inspect
from typing import Optional from typing import Optional
import torch import torch
...@@ -95,15 +96,25 @@ def _apply_logits_processors( ...@@ -95,15 +96,25 @@ def _apply_logits_processors(
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors logits_processors = sampling_params.logits_processors
if logits_processors: if logits_processors:
found_logits_processors = True found_logits_processors = True
for seq_id, logits_row_idx in zip(seq_ids, for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices): seq_group.sample_indices):
logits_row = logits[logits_row_idx] logits_row = logits[logits_row_idx]
token_ids = seq_group.seq_data[seq_id].output_token_ids past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
for logits_processor in logits_processors: for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row) parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids,
past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids,
logits_row)
logits[logits_row_idx] = logits_row logits[logits_row_idx] = logits_row
logits_processed += len(seq_group.sample_indices) + len( logits_processed += len(seq_group.sample_indices) + len(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment