Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
from typing import List, Set, Tuple
import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
import torch
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.utils import (get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
......@@ -22,46 +27,173 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
assert self.lora_config is None, "cpu backend doesn't support LoRA"
#
# Environment variables for CPU executor
#
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Intel OpenMP setting
ld_prealod_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" in ld_prealod_str:
# The time(milliseconds) that a thread should wait after
# completing the execution of a parallel region, before sleeping.
os.environ['KMP_BLOCKTIME'] = "1"
# Prevents the CPU to run into low performance state
os.environ['KMP_TPAUSE'] = "0"
# Provides fine granularity parallelism
os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
# To hint IPEX uses shared memory based AllReduce
os.environ["LOCAL_WORLD_SIZE"] = str(
self.parallel_config.tensor_parallel_size)
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
self.scheduler_config = _verify_and_get_scheduler_config(
self.scheduler_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
def _init_worker(self):
from vllm.worker.cpu_worker import CPUWorker
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
ip = "127.0.0.1"
port = get_open_port()
self.distributed_init_method = get_distributed_init_method(ip, port)
is_async = isinstance(self, CPUExecutorAsync)
world_size = self.parallel_config.tensor_parallel_size
result_handler = ResultHandler()
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
self.workers = []
if is_async:
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
)) for rank in range(0, world_size)
]
self.driver_worker = self.workers[0]
self.workers = self.workers[1:]
self.driver_method_invoker = _async_driver_method_invoker
else:
self.driver_worker = self._create_worker()
self.driver_method_invoker = _driver_method_invoker
if world_size != 1:
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
)) for rank in range(1, world_size)
]
if world_size != 1 or is_async:
if is_async:
async_worker_list = self.workers + [self.driver_worker]
else:
async_worker_list = self.workers
self.worker_monitor = WorkerMonitor(async_worker_list,
result_handler)
result_handler.start()
self.worker_monitor.start()
self._run_workers("init_device")
self._run_workers("load_model")
def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
):
worker_module_name = "vllm.worker.cpu_worker"
worker_class_name = "CPUWorker"
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
assert self.parallel_config.world_size == 1, (
"CPUExecutor only supports single CPU socket currently.")
assert self.distributed_init_method is not None
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = CPUWorker(
kwargs = dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
rank=rank,
distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=True,
is_driver_worker=rank == 0,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
wrapper.init_worker(**kwargs)
return wrapper.worker
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if async_run_remote_workers_only:
# Just return futures
return worker_outputs
driver_worker_output = self.driver_method_invoker(
self.driver_worker, method, *args, **kwargs)
# Get the results of the workers.
return [driver_worker_output
] + [output.get() for output in worker_outputs]
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
return self.driver_method_invoker(self.driver_worker,
"determine_num_available_blocks")
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
......@@ -74,43 +206,95 @@ class CPUExecutor(ExecutorBase):
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
logger.info("# CPU blocks: %d", num_gpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
if (self.parallel_config.tensor_parallel_size > 1
and self.parallel_worker_tasks is None):
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
)
output = self.driver_method_invoker(self.driver_worker,
"execute_model", execute_model_req)
return output
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
"""
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
self.driver_method_invoker(self.driver_worker, "execute_model", None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
return all(self._run_workers("add_lora", lora_request))
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
return all(self._run_workers("remove_lora", lora_id))
def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)
assert lora_id > 0, "lora_id must be greater than 0."
return all(self._run_workers(
"pin_lora",
lora_id=lora_id,
))
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
return self.driver_method_invoker(self.driver_worker, "list_loras")
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
return all(
self._run_workers(
"add_prompt_adapter",
prompt_adapter_request,
))
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
return all(
self._run_workers(
"remove_prompt_adapter",
prompt_adapter_id,
))
def list_prompt_adapters(self) -> Set[int]:
return self.driver_worker.list_prompt_adapters()
return self.driver_method_invoker(self.driver_worker,
"list_prompt_adapters")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
return all(self._run_workers(
"pin_prompt_adapter",
prompt_adapter_id,
))
def check_health(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
"""Raises an error if engine is unhealthy."""
if self.worker_monitor is not None and not self.worker_monitor.is_alive(
):
raise RuntimeError("Worker processes are not running")
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
worker_monitor.close()
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
......@@ -118,14 +302,12 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
output = await make_async(self.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
self.check_health()
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
......@@ -170,3 +352,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
f" {kv_cache_space}, expect a positive integer value.")
return config
def _driver_method_invoker(driver, method: str, *args, **kwargs):
return getattr(driver, method)(*args, **kwargs)
def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
return driver.execute_method(method, *args, **kwargs).get()
import asyncio
import os
import signal
import threading
import weakref
from functools import partial
from typing import Any, List, Optional
import torch
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
......@@ -14,7 +17,6 @@ from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
error_on_invalid_device_count_status,
get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async,
update_environment_variables)
......@@ -44,10 +46,23 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
# contention amongst the shards
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = "1"
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)
# workaround for https://github.com/vllm-project/vllm/issues/6103
if world_size > 1:
......@@ -63,8 +78,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")
error_on_invalid_device_count_status()
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
......@@ -115,8 +128,9 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if executor := ref():
executor.shutdown()
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
......
......@@ -10,10 +10,9 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (_run_task_with_lock,
error_on_invalid_device_count_status,
get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, get_vllm_instance_id,
make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
......@@ -29,6 +28,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def _init_executor(self) -> None:
self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
......@@ -60,8 +60,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
......@@ -107,12 +105,19 @@ class RayGPUExecutor(DistributedGPUExecutor):
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
logger.info("driver_ip: %s", driver_ip)
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
......@@ -144,42 +149,49 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Else, added to the list of workers.
self.workers.append(worker)
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
# the order in `worker_node_and_gpu_ids` does not necessarily match
# the machine boundaries. We need to make sure that workers in the
# same node are assigned consecutive ranks.
# examples:
# [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])] # noqa
# initialize worker ranks with -1 (unassigned)
worker_ranks = [-1 for x in worker_node_and_gpu_ids]
current_rank = 0
while -1 in worker_ranks:
# whenever we find an unassigned worker, find the node
index = worker_ranks.index(-1)
current_node_id = worker_node_and_gpu_ids[index][0]
# assign ranks to all workers in the same node
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
if node_id == current_node_id:
worker_ranks[i] = current_rank
current_rank += 1
# with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3]
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for worker_rank, (node_id, gpu_ids) in zip(worker_ranks,
worker_node_and_gpu_ids):
node_workers[node_id].append(worker_rank)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
......@@ -217,16 +229,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
error_on_invalid_device_count_status()
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id,
_) in zip(worker_ranks, worker_node_and_gpu_ids)
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
......@@ -235,6 +244,19 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
......@@ -245,9 +267,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
self.non_driver_workers: List[RayWorkerWrapper] = []
# Enforce rank order for correct rank to return final output.
for rank, worker in sorted(zip(worker_ranks[1:], self.workers)):
# We need to skip the driver worker, which we
# do by skipping worker_ranks[0] which is always 0.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
......@@ -380,16 +402,47 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.use_ray
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_spmd.bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
# Example DAG: PP=2, TP=4
# (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501
# -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501
# -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501
# -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501
# All workers in the first TP group will take in the
# ExecuteModelRequest as input.
outputs = [input_data for _ in self.pp_tp_workers[0]]
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
# Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last
# pp stage.
transport = "nccl" \
if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
else "auto"
outputs = [
output.with_type_hint(
TorchTensorType(transport=transport))
for output in outputs
]
forward_dag = MultiOutputNode(outputs)
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self):
......
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
Union)
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.tpu_executor import TPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class RayTPUExecutor(TPUExecutor):
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
super().__init__(*args, **kwargs)
def _init_executor(self) -> None:
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel TPU workers.
self._init_workers_ray(placement_group)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("TPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
assert self.speculative_config is None
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"
worker = ray.remote(
num_cpus=0,
resources={"TPU": 1},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
self.workers.append(worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any TPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"TPU node.")
# Get the set of TPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
node_workers = defaultdict(list)
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for _ in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
if len(node_workers) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def determine_num_available_blocks(self) -> Tuple[int, int]:
num_blocks = self._run_workers("determine_num_available_blocks", )
num_tpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_tpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req)
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())
# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)
async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return
await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_xpu
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
......@@ -31,9 +31,17 @@ try:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids
def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
"""Used only when SPMD worker and compiled DAG are both
enabled."""
def execute_model_spmd(
self, req_or_tuple: Union[ExecuteModelRequest,
Tuple[ExecuteModelRequest,
IntermediateTensors]]):
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
Args:
req_or_tuple: The request to execute the model, or a tuple
containing the request and intermediate tensors.
"""
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
......@@ -42,7 +50,17 @@ try:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
return self.worker._execute_model_spmd(execute_model_req)
if isinstance(req_or_tuple, tuple):
execute_model_req, intermediate_tensors = req_or_tuple
else:
execute_model_req = req_or_tuple
intermediate_tensors = None
output = self.worker._execute_model_spmd(execute_model_req,
intermediate_tensors)
if isinstance(output, IntermediateTensors):
return execute_model_req, output
return output
ray_import_err = None
......@@ -93,32 +111,38 @@ def initialize_ray_cluster(
# Placement group is already set.
return
device_str = "GPU" if not is_tpu() else "TPU"
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
# We are in a placement group
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
gpu_bundles = 0
device_bundles = 0
for bundle in bundles:
bundle_gpus = bundle.get("GPU", 0)
if bundle_gpus > 1:
bundle_devices = bundle.get(device_str, 0)
if bundle_devices > 1:
raise ValueError(
"Placement group bundle cannot have more than 1 GPU.")
if bundle_gpus:
gpu_bundles += 1
if parallel_config.world_size > gpu_bundles:
"Placement group bundle cannot have more than 1 "
f"{device_str}.")
if bundle_devices:
device_bundles += 1
if parallel_config.world_size > device_bundles:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the placement group.")
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group."
f"Required number of devices: {parallel_config.world_size}. "
f"Total number of devices: {device_bundles}.")
else:
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
if parallel_config.world_size > num_gpus_in_cluster:
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
if parallel_config.world_size > num_devices_in_cluster:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the cluster.")
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group.")
# Create a new placement group
placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
placement_group_specs = ([{
device_str: 1
}] * parallel_config.world_size)
current_placement_group = ray.util.placement_group(
placement_group_specs)
# Wait until PG is ready - this will block until all
......
......@@ -14,7 +14,6 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
if TYPE_CHECKING:
pass
......@@ -28,7 +27,7 @@ def _fully_sharded_can_replace(can_replace):
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs['lora_config'].fully_sharded_loras)
and kwargs["lora_config"].fully_sharded_loras)
return dec
......@@ -59,25 +58,30 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device,
)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
......@@ -88,14 +92,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
)
def _mcp_apply(x, bias, layer):
def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
"""
MergedColumnParallelLinearWithShardedLoRA and
MergedQKVParallelLinearWithShardedLora share the same
MergedColumnParallelLinearWithShardedLoRA and
MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method.
The main difference is the step by shard_size for lora_b which can
vary for MergedQKVParallelLinearWithShardedLora but is constant for
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
......@@ -104,21 +108,27 @@ def _mcp_apply(x, bias, layer):
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device)
buffers = torch.zeros(
(n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
for idx in range(n):
bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0)
layer.punica_wrapper.add_shrink(buffers[idx], x,
layer.lora_a_stacked[idx], 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]
dispatch_bgmv_low_level(output, buffers[idx],
layer.lora_b_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0,
left_offset, shard_size)
layer.punica_wrapper.add_expand_slice(
output,
buffers[idx],
layer.lora_b_stacked[idx],
left_offset,
shard_size,
add_input=True,
)
left_offset += shard_size
output = output.view(*out_orig_shape)
......@@ -129,7 +139,7 @@ def _mcp_apply(x, bias, layer):
class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA):
"""
Differs from MergedColumnParallelLinearWithLoRA by slicing the
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
......@@ -145,7 +155,8 @@ class MergedColumnParallelLinearWithShardedLoRA(
lora_a = [
lora_a[0][:,
output_start_idx:output_start_idx + output_shard_size],
lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
lora_a[1][:,
output_start_idx:output_start_idx + output_shard_size],
]
return lora_a
......@@ -155,9 +166,13 @@ class MergedColumnParallelLinearWithShardedLoRA(
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
......@@ -170,7 +185,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
"""
Differs from QKVParallelLinearWithLora by slicing the
Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
......@@ -193,14 +208,13 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
......@@ -237,7 +251,7 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
]
return lora_a
......@@ -247,9 +261,13 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
......@@ -262,11 +280,11 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
Differs from RowParallelLinearWithLoRA by slicing the
Differs from RowParallelLinearWithLoRA by slicing the
LoRA B's also.
Based on S-LoRA, slicing happens along the output dim.
This yields a combined partial sum from the row parallel base
This yields a combined partial sum from the row parallel base
layer and column partitioned output from the LoRA.
"""
......@@ -283,11 +301,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device,
)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
......@@ -298,18 +318,21 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0,
start_idx, shard_size)
self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx,
shard_size)
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
......
......@@ -17,10 +17,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.distributed.utils import divide
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
......@@ -55,88 +56,17 @@ def _not_fully_sharded_can_replace(can_replace):
"""
def dec(*args, **kwargs):
decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
condition = (not kwargs['lora_config'].fully_sharded_loras
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
condition = (not kwargs["lora_config"].fully_sharded_loras
if decorate else True)
return can_replace(*args, **kwargs) and condition
return dec
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
indices: torch.Tensor,
output: torch.Tensor,
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
lora_b_stacked: (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
return output.view_as(org_output)
def _apply_lora_packed_nslice(
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
This method is used for layers that are composed of multiple sublayers
(slices) packed together.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx in range(len(output_slices)):
add_lora_slice(output, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
return output.view_as(org_output)
@dataclass
class LoRAMapping(AdapterMapping):
pass
is_prefill: bool = False
class BaseLayerWithLoRA(nn.Module):
......@@ -154,10 +84,11 @@ class BaseLayerWithLoRA(nn.Module):
...
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
"""Initializes lora matrices."""
...
......@@ -177,20 +108,18 @@ class BaseLayerWithLoRA(nn.Module):
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
punica_wrapper: PunicaWrapper,
):
"""Sets the mapping indices."""
...
self.punica_wrapper: PunicaWrapper = punica_wrapper
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
......@@ -259,10 +188,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.embeddings_indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
......@@ -285,40 +210,27 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1]].copy_(embeddings_tensor, non_blocking=True)
shape[1], ].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] *
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
self.embeddings_tensors.shape[2],
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.embeddings_indices = embeddings_indices
self.indices_len = indices_len
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
embeddings_indices = self.punica_wrapper.embeddings_indices
indices = embeddings_indices[1].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
indices = self.embeddings_indices[0][:embedding_len].view_as(x)
indices = embeddings_indices[0].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
......@@ -329,22 +241,125 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1], -1)
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
full_lora_a_embeddings.shape[1],
-1,
)
# Embedding layer only need expand op
self.punica_wrapper.add_expand(full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)
return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is VocabParallelEmbedding
class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
self.output_size = self.base_layer.output_size
self.device = _get_lora_device(self.base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
lora_a_output_size = lora_config.max_lora_rank
self.lora_a_stacked = torch.zeros(
max_loras,
1,
lora_a_output_size,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
max_loras,
1,
self.output_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
)
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output
def forward(self, input_):
"""Forward of ReplicatedLinearWithLoRA
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = (self.base_layer.bias
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
output = self.apply(input_, bias)
output_bias = (self.base_layer.bias
if self.base_layer.skip_bias_add else None)
return output, output_bias
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ReplicatedLinear
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
"""
......@@ -357,10 +372,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.device = _get_lora_device(self.base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
lora_a_output_size_per_partition = (
......@@ -384,10 +400,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
)
self.output_dim = self.lora_b_stacked.shape[2]
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
......@@ -423,28 +435,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output
def forward(self, input_):
......@@ -473,9 +468,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 1)
......@@ -494,10 +493,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
......@@ -533,8 +533,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) for _ in range(n_slices))
self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
......@@ -556,7 +554,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = [
lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
lora_b[0][:, start_idx:end_idx],
lora_b[1][:, start_idx:end_idx],
]
return lora_b
......@@ -591,34 +590,33 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
(self.output_dim, self.output_dim),
)
self.punica_wrapper.add_lora_packed_nslice(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
(self.output_dim, self.output_dim))
return output
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is MergedColumnParallelLinear and len(
packed_modules_list) == 2
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 2)
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer.
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b
During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
......@@ -696,10 +694,11 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
......@@ -767,11 +766,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
),
)
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
self.kv_proj_shard_size)
self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
......@@ -794,15 +797,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
(self.q_shard_id + 1), ]
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
(self.kv_shard_id + 1), ]
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
(self.kv_shard_id + 1), ]
lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b
......@@ -851,23 +854,23 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
self.output_slices,
)
self.punica_wrapper.add_lora_packed_nslice(output, x,
self.lora_a_stacked,
self.lora_b_stacked, 1.0,
self.output_slices)
return output
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 3
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is QKVParallelLinear
and len(packed_modules_list) == 3)
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
......@@ -880,10 +883,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.device = _get_lora_device(self.base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config
self.tp_rank = get_tensor_model_parallel_rank()
self.lora_a_stacked = torch.zeros(
......@@ -911,9 +915,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
......@@ -950,27 +951,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output
def forward(self, input_):
......@@ -1013,14 +997,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@property
def weight(self):
return self.base_layer.weight if hasattr(
self.base_layer, "weight") else self.base_layer.qweight
return (self.base_layer.weight if hasattr(self.base_layer, "weight")
else self.base_layer.qweight)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is RowParallelLinear
......@@ -1067,6 +1055,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def soft_cap(self):
return self.base_layer.soft_cap
@property
def use_gather(self):
return self.base_layer.use_gather
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
......@@ -1081,7 +1073,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
# TODO: Verify if this condition can be relaxed
if 32000 < self.base_layer.vocab_size > 128512:
raise ValueError("When using LoRA, vocab size must be "
"32000 >= vocab_size <= 128512")
......@@ -1121,10 +1113,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=torch.long)
else:
self.sharded_to_full_mapping_gpu = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.indices_padded: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
......@@ -1150,19 +1138,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1], ] = embeddings_tensor
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = sampler_indices
self.indices_padded = sampler_indices_padded
self.indices_len = indices_len
def _get_logits(
self,
hidden_states: torch.Tensor,
......@@ -1208,38 +1183,37 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
out=lora_logits[:-1])
lora_logits[-1] = float("-inf")
lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded
lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0,
self.indices_padded[:self.indices_len[2]]).nan_to_num_(
nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
_apply_lora(
hidden_states,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[1]],
logits,
)
lora_logits.shape[1], ] = lora_logits
# LogitsProcessorWithLoRA always using bgmv
self.punica_wrapper.add_lora_logits(logits, hidden_states,
self.lora_a_stacked,
self.lora_b_stacked, 1.0)
# Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size]
return logits
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# Special handling for the LogitsProcessor.
return False
......@@ -1255,9 +1229,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
def __init__(self, base_layer: RotaryEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
# Lazily initialized
self.long_lora_indices: torch.Tensor
self.indices_len: List[int]
@property
def scaling_factors(self):
......@@ -1273,9 +1244,8 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
scaling_factors = list(
lora_config.long_lora_scaling_factors
) if lora_config.long_lora_scaling_factors else []
scaling_factors = (list(lora_config.long_lora_scaling_factors)
if lora_config.long_lora_scaling_factors else [])
base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
scaling_factors = sorted(
......@@ -1302,18 +1272,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
):
...
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.long_lora_indices = long_lora_indices
self.indices_len = indices_len
def forward(
self,
positions: torch.Tensor,
......@@ -1324,19 +1282,24 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
positions,
query,
key,
offsets=self.long_lora_indices[:self.indices_len[4]])
offsets=self.punica_wrapper.long_lora_indices,
)
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
return self.base_layer.scaling_factor_to_offset
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
return type(source_layer) is LinearScalingRotaryEmbedding or type(
source_layer) is RotaryEmbedding
return (type(source_layer) is LinearScalingRotaryEmbedding
or type(source_layer) is RotaryEmbedding)
def extra_repr(self) -> str:
return self.base_layer.extra_repr()
......@@ -4,7 +4,7 @@ import math
import os
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type
import safetensors.torch
import torch
......@@ -21,6 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
......@@ -43,115 +44,6 @@ class LongContextLoRAContext:
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
def convert_mapping(
mapping: LoRAMapping,
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional[LongContextLoRAContext] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors.
Used to index into each tensor. It contains length for
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices). If long_lora doesn't
exist, it only contains first 4 entries.
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices, lora_indices, embedding_indices
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices, indices_len)
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
......@@ -422,29 +314,12 @@ class LoRAModelManager(AdapterModelManager):
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.embeddings_indices = torch.empty(2,
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device="cuda")
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4
super().__init__(model)
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
......@@ -536,28 +411,16 @@ class LoRAModelManager(AdapterModelManager):
"Pinning is not supported in LoRAModelManager."
"Use LRUCacheLoRAModelManager for pinning") # type: ignore
# TODO see if this can be vectorized
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_offsets_tensor,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self.long_lora_indices.zero_()
# Maintain the reference
self.indices_len[:] = indices_len
# update lora states
self.punica_wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""
......@@ -595,10 +458,8 @@ class LoRAModelManager(AdapterModelManager):
self.model.config))
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded,
self.embeddings_indices,
self.long_lora_indices, self.indices_len)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(self.punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
......
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Dict, Optional
import torch
import triton
import triton.language as tl
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
performance
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
current_n_c = tl.max_contiguous(current_n, BLOCK_N)
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n_c[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
override_config: Optional[Dict[str, int]] = None,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
if override_config:
config = override_config
else:
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Dict, Optional
import torch
import triton
import triton.language as tl
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_slice_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
performance
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
slice_offset * cn_stride)
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
override_config: Optional[Dict[str, int]] = None,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'b weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
if override_config:
config = override_config
else:
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_slice_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Dict, Optional
import torch
import triton
import triton.language as tl
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_shrink_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
scaling,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's
performance
"""
pid_sk = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_n = tl.arange(0, BLOCK_N)
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
a_ptr = input_ptr + cur_batch * xm_stride
b_ptr = lora_ptr + l0_stride * lora_index
accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
for k in range(0, K, BLOCK_K * SPLIT_K):
current_k = k + offset_k
current_k_c = tl.max_contiguous(current_k, BLOCK_K)
tiled_a = tl.load(
a_ptr + current_k_c,
mask=current_k < K,
other=0.0,
) # [BLOCK_K]
b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
tiled_b = tl.load(
b_ptr + offset_n[:, None] * lora_k_stride +
current_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
accumulator += tl.sum(tiled_a * tiled_b, 1)
accumulator *= scaling
offset_cn = tl.arange(0, BLOCK_N)
c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
c_mask = offset_cn < N
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def bgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
override_config: Optional[Dict[str, int]] = None,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
scaling (float): Scaling factor.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_a_weights.size(-1)
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
batches = lora_indices_tensor.size(0)
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
BLOCK_N = triton.next_power_of_2(N)
if override_config:
config = override_config
else:
# First try to load optimal config from the file
config = get_lora_op_configs("bgmv_shrink", batches, K)
grid = lambda META: (
META["SPLIT_K"],
batches,
)
_bgmv_shrink_kernel[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_N=BLOCK_N,
**config,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
The sgmv's expand triton kernel is based on GroupGEMM.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
offset_k[None, :] * xk_stride, )
b_ptr = (lora_ptr + l0_stride * lora_index +
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
if EVEN_K:
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr)
else:
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :] < K - k * BLOCK_K,
other=0)
tiled_b = tl.load(b_ptr,
mask=offset_k[:, None] < K - k * BLOCK_K,
other=0)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
add_inputs: bool = False,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_M = 32
BLOCK_N = 32
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
_sgmv_expand_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_expand_slice_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
offset_k[None, :] * xk_stride, )
b_ptr = (lora_ptr + l0_stride * lora_index +
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
if EVEN_K:
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr)
else:
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :] < K - k * BLOCK_K,
other=0)
tiled_b = tl.load(b_ptr,
mask=offset_k[:, None] < K - k * BLOCK_K,
other=0)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
(slice_offset + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
):
"""_summary_
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output..
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_M = 32
BLOCK_N = 32
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
_sgmv_expand_slice_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_shrink_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
scaling,
xm_stride, # hidden_size
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
introducing SPLIT-K can improve performance
"""
pid = tl.program_id(axis=0)
pid_sk = tl.program_id(axis=1)
cur_batch = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
offset_k[None, :] * xk_stride)
b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +
offset_k[:, None] * lora_n_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :] < k_remaining,
other=0.0)
tiled_b = tl.load(b_ptr,
mask=offset_k[:, None] < k_remaining,
other=0.0)
accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += BLOCK_K * SPLIT_K * xk_stride
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
scaling: float,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
BLOCK_M = 32
BLOCK_N = 16
BLOCK_K = 32
SPLIT_K = 8
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K,
batches,
)
_sgmv_shrink_kernel[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
)
return
import functools
from typing import Dict
@functools.lru_cache
def _get_op_configs(op_type: str, batch: int, hidden_size: int):
# TODO: add optimal configurations
return None
def _check_divisibility(hidden_size: int):
# The bgmv_expand kernel requires that the hidden_size be divisible by
# the number below.
divisibility = [2, 4, 8, 16, 32, 64]
divisibility.sort(reverse=True)
for div in divisibility:
if hidden_size % div == 0:
return div
# hidden_size is an odd number
return 1
def _get_default_config(op_type: str, batch: int, hidden_size: int):
if op_type == "expand":
return {
"BLOCK_N": 256,
"SPLIT_N": _check_divisibility(hidden_size),
"num_warps": 8
}
else:
return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8}
def get_lora_op_configs(op_type: str, batch: int,
hidden_size: int) -> Dict[str, int]:
"""Inspired by `fused_moe_kernel`
The return value will be a dictionary mapping an irregular grid of batch
sizes and hidden_size to configurations of the bgmv-related kernel.
NOTE: It currently only supports the default configuration. We plan to
generate optimal configurations for different hardware in the future using
scripts similar to `benchmark_moe.py`.
"""
config = _get_op_configs(op_type, batch, hidden_size)
if not config:
config = _get_default_config(op_type, batch, hidden_size)
return config
# Based on code from https://github.com/punica-ai/punica
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
if current_platform.get_device_capability() < (8, 0):
raise ImportError(
"punica LoRA kernels require compute capability >= 8.0")
else:
raise ImportError(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.")
def bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
matrices.
indicies: Shape: `[B]`. Indices of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
_check_punica_support()
ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
token_lora_tensor, return_counts=True)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, indicies: torch.LongTensor,
layer_idx: int, scale: float, y_offset: int,
y_slice_size: int):
"""
Same as `bgmv` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
all of the transposed LoRA matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
_check_punica_support()
ops.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
x.size(1),
y_slice_size,
y_offset,
)
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
else:
# If long_lora doesn't exist,append None
indices_len.append(None)
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
*,
buffer: Optional[torch.Tensor] = None):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
return (
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_indices,
indices_len,
)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
class PunicaWrapper:
"""
_check_punica_support()
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)
def add_lora_slice(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
y_offset: int,
y_slice_size: int,
*,
buffer: Optional[torch.Tensor] = None):
PunicaWrapper is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel.
"""
Same as `add_lora` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: str):
self._token_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
self._sampler_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
self._embeddings_indices = torch.empty(2,
max_num_batched_tokens,
dtype=torch.long,
device=device)
self._long_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
_check_punica_support()
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
ops.dispatch_bgmv_low_level(
buffer,
x,
wa_t_all,
indicies,
layer_idx,
1.0,
x.size(1),
buffer.size(1),
0,
)
ops.dispatch_bgmv_low_level(
y,
buffer,
wb_t_all,
indicies,
layer_idx,
scale,
buffer.size(1),
y_slice_size,
y_offset,
)
# 5 is the number of indicies tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
self.indices_len: List[Optional[int]] = [None] * 5
# these attributes are the information required for sgmv kernel
self._seq_start_locs = torch.empty(max_batches,
dtype=torch.long,
device=device)
self._seq_lengths = torch.empty(max_batches,
dtype=torch.long,
device=device)
self._lora_indices_per_batch = torch.empty(max_batches,
dtype=torch.long,
device=device)
self.max_length: int = 0
self.batch_size: int = -1
self.is_prefill = False
self.no_lora = False
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metada(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_offsets_tensor,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
long_lora_context,
)
self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
self._embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self._long_lora_indices.zero_()
self.indices_len[:] = indices_len
def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
b_seq_start_tensor)
self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
lora_indices_tensor)
self.batch_size = batch_size
self.max_length = max_length
self.no_lora = no_lora
@property
def prefill_metadata(
self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
"""
return (self._seq_start_locs[:self.batch_size],
self._seq_lengths[:self.batch_size],
self._lora_indices_per_batch[:self.batch_size],
self.batch_size, self.max_length)
@property
def token_lora_indices(self) -> torch.Tensor:
"""
This property provides the lora indices corresponding to each token
in the batch. An index of -1 means no lora should be applied.
"""
token_lora_len = self.indices_len[0]
return self._token_lora_indices[:token_lora_len]
@property
def sampler_indices(self) -> torch.Tensor:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
"""
sampler_indices_len = self.indices_len[1]
return self._sampler_indices[:sampler_indices_len]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices
"""
indices_padded_len = self.indices_len[2]
return self._sampler_indices_padded[:indices_padded_len]
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
"""
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
@property
def long_lora_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
"""
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
def shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_input,
)
def expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
def expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_input,
)
def expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input)
def add_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the shrink_decode function
should be called.
"""
shrink_fun: Callable = (self.shrink_prefill
if self.is_prefill else self.shrink_decode)
shrink_fun(y, x, w_t_all, scale)
def add_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool = True,
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'b.
When `is_prefill` is true, it indicates that it is currently the
prefill stage, and the `expand_prefill` function should be called.
Otherwise, it is the decode stage, and the expand_decode function
should be called.
"""
expand_fun: Callable = (self.expand_prefill
if self.is_prefill else self.expand_decode)
expand_fun(y, x, w_t_all, add_input)
def add_expand_slice(self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True):
"""
Similar to `add_expand`
"""
expand_slice_fun: Callable = (self.expand_slice_prefill
if self.is_prefill else
self.expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
def add_lora(self,
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
scale: float,
y_offset: Optional[int] = None,
y_slice_size: Optional[int] = None,
*,
buffer: Optional[torch.Tensor] = None) -> None:
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
wa_t_all (torch.Tensor): lora_a's weight
wb_t_all (torch.Tensor): lora_b's weight
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice..
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
self.add_shrink(buffer, x, wa_t_all, scale)
if y_offset is None and y_slice_size is None:
self.add_expand(y, buffer, wb_t_all, add_input=True)
else:
self.add_expand_slice(y,
buffer,
wb_t_all,
y_offset,
y_slice_size,
add_input=True)
y = y.view_as(y_org)
def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor],
scale: float,
output_slices: Tuple[int, ...]) -> None:
"""
Applies lora to each input. Similar to add_lora, This method is
used for layers that are composed of multiple sublayers
(slices) packed together.
"""
y_org = y
x = x.view(-1, x.shape[-1])
y = y.view(-1, y.shape[-1])
offset_left = 0
# TODO fuse these kernels
for slice_idx in range(len(output_slices)):
self.add_lora(y, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], scale, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None) -> None:
"""
LogitsProcessorWithLoRA always using bgmv
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
y = y.view_as(y_org)
......@@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
......@@ -38,6 +39,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora,
......
......@@ -3,9 +3,10 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
......@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
......@@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
"Must be one of 'outlines, 'lm-format-enforcer'")
def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
......
from dataclasses import dataclass
from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
guided_regex: str
guided_choice: List[str]
guided_grammar: str
guided_decoding_backend: str
guided_whitespace_pattern: str
guided_json_object: bool
@dataclass
class GuidedDecodingRequest:
"""One of the fields will be used to retrieve the logit processor."""
guided_json: Optional[Union[Dict, BaseModel, str]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
guided_decoding_backend: Optional[str] = None
guided_whitespace_pattern: Optional[str] = None
guided_json_object: Optional[bool] = None
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
self.guided_json is not None, self.guided_regex is not None,
self.guided_choice is not None, self.guided_grammar is not None,
self.guided_json_object is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment