"tests/vscode:/vscode.git/clone" did not exist on "402759d4727d9a377598a09d06770770d4e184c6"
Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
......@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
If speculative decoding is enabled, we instead create the speculative
worker.
"""
if self.speculative_config is None:
self._init_non_spec_worker()
else:
self._init_spec_worker()
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _get_worker_kwargs(
self,
......@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config,
is_driver_worker=rank == 0,
)
......@@ -52,53 +52,22 @@ class GPUExecutor(ExecutorBase):
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
target_worker = self._create_worker()
draft_worker_kwargs = self._get_worker_kwargs()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
)
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = spec_decode_worker
# Load model handled in spec decode worker.
self.driver_worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
......@@ -117,8 +86,8 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = self.driver_worker.execute_model(execute_model_req)
return output
......@@ -144,7 +113,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
import asyncio
import os
from functools import partial
from typing import Any, List, Optional
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
logger = init_logger(__name__)
class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor"""
def _init_executor(self) -> None:
assert (
not self.speculative_config
), "Speculative decoding not yet supported for MultiProcGPU backend."
# Create the parallel GPU workers.
world_size = self.parallel_config.tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
map(str, range(world_size))))
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
from torch.cuda import device_count
assert world_size <= device_count(), (
"please set tensor_parallel_size to less than max local gpu count")
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
if world_size == 1:
self.workers = []
else:
result_handler = ResultHandler()
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)) for rank in range(1, world_size)
]
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
self.worker_monitor.start()
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
worker_monitor.close()
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_model(
execute_model_req=execute_model_req)
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if async_run_remote_workers_only:
# Just return futures
return worker_outputs
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
# Get the results of the workers.
return [driver_worker_output
] + [output.get() for output in worker_outputs]
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
if not self.worker_monitor.is_alive():
raise RuntimeError("Worker processes are not running")
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_model(execute_model_req)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)
......@@ -28,10 +28,7 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
......@@ -45,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
......@@ -90,14 +89,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
......@@ -107,8 +114,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:
......@@ -166,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={"execute_model_req": execute_model_req},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
# Only the driver worker returns the sampling results.
return all_outputs[0]
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
async_run_remote_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
......@@ -193,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
"""Runs the given method on all workers. Can be used in the following
ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
......@@ -204,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
......@@ -220,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else:
# Start the ray workers first.
ray_worker_outputs = [
......@@ -229,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
......@@ -255,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
return [driver_worker_output] + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self):
import pkg_resources
required_version = "2.9"
......@@ -264,7 +281,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray
assert self.parallel_config.distributed_executor_backend == "ray"
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
......@@ -298,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method)
self.driver_exec_method = make_async(self.driver_worker.execute_method)
async def _run_workers_async(
async def _driver_execute_model_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
coros.append(
self.driver_executor(method, *driver_args, **driver_kwargs))
# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
return all_outputs
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)
......@@ -44,7 +44,7 @@ try:
except ImportError as e:
logger.warning(
"Failed to import Ray with %r. For distributed inference, "
"Failed to import Ray with %r. For multi-node inference, "
"please install Ray with `pip install ray`.", e)
ray = None # type: ignore
RayWorkerWrapper = None # type: ignore
......@@ -67,7 +67,7 @@ def initialize_ray_cluster(
"""
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"Ray is not installed. Please install Ray to use multi-node "
"serving.")
# Connect to a ray cluster.
......
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.sequence import MultiModalData
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt."""
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: NotRequired[Optional[str]]
multi_modal_data: NotRequired[Optional["MultiModalData"]]
......@@ -14,6 +14,7 @@ import vllm.envs as envs
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
......@@ -30,7 +31,7 @@ DEFAULT_LOGGING_CONFIG = {
"vllm": {
"class": "logging.StreamHandler",
"formatter": "vllm",
"level": "INFO",
"level": VLLM_LOGGING_LEVEL,
"stream": "ext://sys.stdout",
},
},
......
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import torch.nn as nn
......@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
......@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
)
def _mcp_apply_weights(x, bias, layer):
def _mcp_apply(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
......@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.linear_method.apply_weights(
layer.base_layer, x, bias)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
......@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_a[0] is None or lora_a[1] is None:
return lora_a
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
for i in range(2)
lora_a[0][:,
output_start_idx:output_start_idx + output_shard_size],
lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
......@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
return lora_a
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[i][:, start_idx[i]:start_idx[i] +
shard_size[i]] if lora_a[i] is not None else None
for i in range(3)
lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
......@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
......
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -22,6 +22,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -145,11 +147,15 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
def slice_lora_a(
self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
def slice_lora_b(
self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
"""Slice lora b if splitting with tensor parallelism."""
...
......@@ -181,6 +187,7 @@ class BaseLayerWithLoRA(nn.Module):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
"""Sets the mapping indices."""
......@@ -302,6 +309,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
......@@ -427,6 +435,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
......@@ -539,10 +548,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_b[0] is None or lora_b[1] is None:
return lora_b
shard_size = self.output_dim
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
......@@ -767,10 +782,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
lora_b_q, lora_b_k, lora_b_v = None, None, None
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
......@@ -936,6 +956,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
......@@ -992,7 +1013,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@property
def weight(self):
return self.base_layer.weight if hasattr(
self.base_layer, "weight") else self.base_layer.qweight
......@@ -1113,6 +1133,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = sampler_indices
......@@ -1179,3 +1200,101 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False
class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
"""Implements RoPE-scaled embeddings with linear scaling for
multiple LoRA adapters with a specialized kernel.
Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
which can handle multi lora adapters in a specialied kernel.
"""
def __init__(self, base_layer: RotaryEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
# Lazily initialized
self.long_lora_indices: torch.Tensor
self.indices_len: List[int]
@property
def scaling_factors(self):
return self.base_layer.scaling_factors
@property
def rotary_dim(self):
return self.base_layer.rotary_dim
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
scaling_factors = list(
lora_config.long_lora_scaling_factors
) if lora_config.long_lora_scaling_factors else []
base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
scaling_factors = sorted(
list(set([base_scaling_factor] + scaling_factors)))
self.base_layer = LinearScalingRotaryEmbedding(
self.base_layer.head_size,
self.base_layer.rotary_dim,
self.base_layer.max_position_embeddings,
self.base_layer.base,
self.base_layer.is_neox_style,
scaling_factors,
self.base_layer.dtype,
)
def reset_lora(self, index: int):
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
...
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.long_lora_indices = long_lora_indices
self.indices_len = indices_len
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.base_layer(
positions,
query,
key,
offsets=self.long_lora_indices[:self.indices_len[4]])
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
return self.base_layer.scaling_factor_to_offset
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
return type(source_layer) is LinearScalingRotaryEmbedding or type(
source_layer) is RotaryEmbedding
def extra_repr(self) -> str:
return self.base_layer.extra_repr()
......@@ -3,7 +3,8 @@ import json
import math
import os
import re
from typing import Callable, Dict, List, Optional, Tuple, Type
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors.torch
import torch
......@@ -11,7 +12,9 @@ from torch import nn
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
......@@ -22,10 +25,27 @@ logger = init_logger(__name__)
_GLOBAL_LORA_ID = 0
@dataclass
class LongContextLoRAContext:
"""Context for lora adapters that support long context."""
# The scaling factors to support long context lora fine tuned models.
scaling_factors: List[float]
# dimension to apply rotary embedding.
rot_dim: int
# offsets to the sin_cos_cache for each lora_id loaded.
# This value is dynamically modified.
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
def convert_mapping(
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
mapping: LoRAMapping,
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional[LongContextLoRAContext] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
......@@ -34,6 +54,7 @@ def convert_mapping(
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
......@@ -51,11 +72,23 @@ def convert_mapping(
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors.
Used to index into each tensor. It contains length for
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices). If long_lora doesn't
exist, it only contains first 4 entries.
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
......@@ -66,13 +99,20 @@ def convert_mapping(
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx
indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices, lora_indices, embedding_indices
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
......@@ -89,13 +129,21 @@ def convert_mapping(
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
embeddings_indices, long_lora_indices, indices_len)
def get_lora_id():
......@@ -112,13 +160,35 @@ class LoRAModel:
lora_model_id: int,
rank: int,
loras: Dict[str, LoRALayerWeights],
scaling_factor: Optional[float] = None,
) -> None:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
scaling_factor: Scaling factor to support long context lora model.
None if the lora is not tuned for long context support.
"""
self.id = lora_model_id
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
self.scaling_factor = scaling_factor
assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel":
"""Return a copy of the object with different ids.
Will share the underlying tensors."""
return self.__class__(
lora_model_id,
rank=self.rank,
loras=self.loras.copy(),
)
@property
def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size
......@@ -140,6 +210,7 @@ class LoRAModel:
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
scaling_factor: Optional[float] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
......@@ -189,13 +260,15 @@ class LoRAModel:
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, rank, loras)
return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
*,
max_position_embeddings: Optional[int] = None,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
......@@ -203,7 +276,23 @@ class LoRAModel:
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint."""
"""Create a LoRAModel from a local checkpoint.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
max_position_embeddings: Max position embedding length. Used to
scaling the largest context length. If None, the lora model's
context length is not scaled.
lora_model_id: Lora model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
Returns:
Loaded LoRA Model.
"""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
......@@ -221,7 +310,9 @@ class LoRAModel:
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules:
print(unexpected_modules, "modules")
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
......@@ -243,6 +334,14 @@ class LoRAModel:
rank = config["r"]
lora_alpha = config["lora_alpha"]
context_length = config.get("context_length", None)
scaling_factor = None
if context_length:
if max_position_embeddings is None:
max_position_embeddings = context_length
scaling_factor = float(
math.ceil(context_length / max_position_embeddings))
return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
......@@ -253,6 +352,7 @@ class LoRAModel:
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
scaling_factor=scaling_factor,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
)
......@@ -286,6 +386,7 @@ class LoRAModelManager:
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
......@@ -299,6 +400,12 @@ class LoRAModelManager:
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
......@@ -308,6 +415,10 @@ class LoRAModelManager:
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
......@@ -373,12 +484,32 @@ class LoRAModelManager:
return True
return False
def _set_long_lora_context(self, lora: LoRAModel):
if self.long_lora_context is None:
return
if lora.scaling_factor is None:
return
if (lora.scaling_factor not in self.scaling_factor_to_offset):
raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
" has not been initialized.")
offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
if offsets:
self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
self._set_long_lora_context(lora)
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
......@@ -390,15 +521,18 @@ class LoRAModelManager:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
if self.long_lora_context:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices,
embeddings_indices, long_lora_offsets_tensor,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size)
self.lora_config.lora_extra_vocab_size,
self.long_lora_context)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
......@@ -406,6 +540,11 @@ class LoRAModelManager:
self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self.long_lora_indices.zero_()
# Maintain the reference
self.indices_len[:] = indices_len
......@@ -428,7 +567,8 @@ class LoRAModelManager:
self._active_loras.clear()
def _create_lora_modules(self):
for module_name, module in self.model.named_modules():
for module_name, module in self.model.named_modules(
remove_duplicate=False):
if not self._match_target_modules(module_name):
continue
parts = module_name.split(".")[-1]
......@@ -437,6 +577,13 @@ class LoRAModelManager:
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
self.long_lora_context = LongContextLoRAContext(
new_module.scaling_factors, new_module.rotary_dim)
self.scaling_factor_to_offset = \
new_module.scaling_factor_to_offset
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module = self.model.get_submodule(
......@@ -451,7 +598,8 @@ class LoRAModelManager:
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded,
self.embeddings_indices, self.indices_len)
self.embeddings_indices,
self.long_lora_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
......@@ -461,12 +609,14 @@ class LoRAModelManager:
self,
lora_id: int,
rank: int,
scaling_factor: Optional[float],
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA):
module, BaseLayerWithLoRA) or isinstance(
module, LinearScalingRotaryEmbeddingWithLora):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
......@@ -596,6 +746,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
self._add_lora(lora)
was_added = True
......
from dataclasses import dataclass
from typing import Optional
@dataclass
......@@ -18,6 +19,7 @@ class LoRARequest:
lora_name: str
lora_int_id: int
lora_local_path: str
long_lora_max_len: Optional[int] = None
def __post_init__(self):
if self.lora_int_id < 1:
......
......@@ -13,6 +13,7 @@ from vllm.lora.fully_sharded_layers import (
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
......@@ -26,12 +27,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
MergedQKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA,
LinearScalingRotaryEmbeddingWithLora,
}
......
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Set, Type
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
import torch
......@@ -16,15 +17,31 @@ logger = init_logger(__name__)
class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
vocab_size: int, lora_config: LoRAConfig,
device: torch.device):
def __init__(self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
max_position_embeddings: Optional[int] = None):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
# If False, do not cache. If None, cache is empty.
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@abstractproperty
def is_enabled(self) -> bool:
...
......@@ -80,14 +97,21 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_modules: Dict[str, str],
embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
super().__init__(
max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
device,
max_position_embeddings=max_position_embeddings,
)
@property
def is_enabled(self) -> bool:
......@@ -150,6 +174,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
expected_lora_modules,
max_position_embeddings=self.max_position_embeddings,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
......@@ -174,9 +199,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
return self._lora_manager.add_lora(
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
rank, self.embedding_modules))
if isinstance(self._cached_dummy_lora, LoRAModel):
dummy_lora = self._cached_dummy_lora.clone(
lora_request.lora_int_id)
else:
dummy_lora = self._lora_manager.create_dummy_lora(
lora_request.lora_int_id, rank, 1, self.embedding_modules)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._lora_manager.add_lora(dummy_lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras():
......
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, get_config_file_name)
fused_experts, fused_moe, fused_topk, get_config_file_name)
__all__ = [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
]
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8
},
"48": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_stages": 0
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_stages": 0
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_stages": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 0
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_stages": 0
}
}
......@@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int,
return None
def fused_moe(
def fused_topk(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape
E, N, _ = w1.shape
if is_hip():
# The MoE kernels are not yet supported on ROCm.
......@@ -393,6 +349,33 @@ def fused_moe(
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape
E, N, _ = w1.shape
if override_config:
config = override_config
......@@ -477,3 +460,63 @@ def fused_moe(
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
override_config=override_config,
use_fp8=use_fp8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
......@@ -59,7 +59,6 @@ class LinearMethodBase(QuantizeMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
......@@ -81,8 +80,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
......@@ -161,15 +159,13 @@ class ReplicatedLinear(LinearBase):
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
......@@ -222,17 +218,15 @@ class ColumnParallelLinear(LinearBase):
the list would be size 3.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
):
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
......@@ -240,18 +234,26 @@ class ColumnParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size)
assert self.quant_method is not None
self.output_size_per_partition = divide(self.output_size, tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, tp_size)
for output_size in self.output_sizes
]
if output_sizes is None:
output_sizes = [output_size]
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
......@@ -333,24 +335,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, quant_config,
self.output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
......@@ -360,6 +365,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
param_shard_splitter = getattr(param, "shard_splitter", None)
if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
......@@ -424,6 +449,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
......@@ -436,7 +468,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if fp8_scales_shard_indexer is None:
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
......@@ -448,6 +487,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
......@@ -472,17 +512,15 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
......@@ -502,14 +540,18 @@ class QKVParallelLinear(ColumnParallelLinear):
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
output_sizes = [
self.num_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, quant_config, output_sizes)
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self,
......@@ -520,6 +562,26 @@ class QKVParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
param_shard_splitter = getattr(param, "shard_splitter", None)
if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
......@@ -558,6 +620,8 @@ class QKVParallelLinear(ColumnParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
......@@ -601,6 +665,12 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
......@@ -612,6 +682,11 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
......@@ -650,17 +725,15 @@ class RowParallelLinear(LinearBase):
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
......@@ -670,16 +743,15 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
......@@ -708,12 +780,16 @@ class RowParallelLinear(LinearBase):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
......
"""A layer that compute logits from hidden_stats."""
import inspect
from typing import Optional
import torch
......@@ -95,15 +96,25 @@ def _apply_logits_processors(
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
token_ids = seq_group.seq_data[seq_id].output_token_ids
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids,
past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids,
logits_row)
logits[logits_row_idx] = logits_row
logits_processed += len(seq_group.sample_indices) + len(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment