Unverified Commit ca871491 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc][LoRA] Abstract PunicaWrapper (#10955)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 3b61cb45
...@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, ...@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
# yapf: enable # yapf: enable
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights) PackedLoRALayerWeights)
from vllm.lora.punica import PunicaWrapper from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -48,11 +48,12 @@ TOLERANCES = { ...@@ -48,11 +48,12 @@ TOLERANCES = {
torch.float32: (5e-3, 5e-3), torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2), torch.bfloat16: (3e-2, 2e-2),
} }
CUDA_DEVICES = [ # TODO: Modify this based on platform
DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
# We will launch different triton kernels between the prefill and decode #For GPU, we will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False) # stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False] STAGES = [True, False]
...@@ -192,9 +193,18 @@ def create_random_inputs( ...@@ -192,9 +193,18 @@ def create_random_inputs(
return inputs, index_mapping, prompt_mapping return inputs, index_mapping, prompt_mapping
def check_punica_wrapper(punica_wrapper) -> bool:
if current_platform.is_cuda_alike():
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
return type(punica_wrapper) is PunicaWrapperGPU
else:
return False
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: ...@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: ...@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# @pytest.mark.skip( # @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.") # reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device, def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, ...@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, ...@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False]) @pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage, def test_linear_replicated(dist_init, num_loras, device, stage,
...@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage, ...@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage, ...@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False]) @pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False]) @pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seed = 0 seed = 0
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
......
...@@ -17,7 +17,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -17,7 +17,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather) tensor_model_parallel_gather)
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.lora.punica import PunicaWrapper
# yapf: disable # yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, LinearBase,
...@@ -33,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -33,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
if TYPE_CHECKING: if TYPE_CHECKING:
pass from vllm.lora.punica_wrapper import PunicaWrapperBase
def _get_lora_device(base_layer: nn.Module) -> torch.device: def _get_lora_device(base_layer: nn.Module) -> torch.device:
...@@ -115,9 +114,9 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -115,9 +114,9 @@ class BaseLayerWithLoRA(nn.Module):
def set_mapping( def set_mapping(
self, self,
punica_wrapper: PunicaWrapper, punica_wrapper,
): ):
self.punica_wrapper: PunicaWrapper = punica_wrapper self.punica_wrapper: PunicaWrapperBase = punica_wrapper
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
......
...@@ -21,7 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ...@@ -21,7 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora, LinearScalingRotaryEmbeddingWithLora,
LoRAMapping) LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules, is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
...@@ -331,7 +331,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -331,7 +331,7 @@ class LoRAModelManager(AdapterModelManager):
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs, max_batches=self.max_num_seqs,
device=self.device) device=self.device)
# Scaling factor -> offset to the sin_cos_cache to it. # Scaling factor -> offset to the sin_cos_cache to it.
......
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
__all__ = [
"PunicaWrapperBase",
"get_punica_wrapper",
]
...@@ -5,19 +5,12 @@ Punica: Multi-Tenant LoRA Serving. ...@@ -5,19 +5,12 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch import torch
from vllm.triton_utils import HAS_TRITON from .utils import compute_meta, convert_mapping
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
if TYPE_CHECKING: if TYPE_CHECKING:
# avoid circuit import # avoid circuit import
...@@ -25,166 +18,117 @@ if TYPE_CHECKING: ...@@ -25,166 +18,117 @@ if TYPE_CHECKING:
from vllm.lora.models import LongContextLoRAContext from vllm.lora.models import LongContextLoRAContext
def compute_meta( class PunicaWrapperABC(ABC):
token_lora_tensor: torch.Tensor """
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: PunicaWrapper ABC.
""" """
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function @abstractmethod
will combine them into a single request, improving sgmv kernel inference def update_metadata(
performance. self,
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
token_lora_tensor, return_counts=True)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, token_nums, no_lora)
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping", mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]], lora_index_to_id: List[Optional[int]],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int, extra_vocab_size: int,
device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None, long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, **kwargs,
Optional[torch.Tensor], List[int]]: ) -> None:
"""Converts LoRAMapping to index tensors. """
Update the lora-related metadata
"""
raise NotImplementedError
Args: @abstractmethod
mapping: LoRAMapping mapping rows in a batch to LoRA ids. def add_shrink(
lora_index_to_id: List mapping LoRA ids to LoRA indices. self,
max_loras: Maximum number of LoRAs. y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
vocab_size: Model vocab size. x: torch.Tensor,
extra_vocab_size: Extra vocab size each LoRA can have. lora_a_stacked: Tuple[torch.Tensor, ...],
long_lora_context: Passed if there are long context lora in a batch. scale: float,
**kwargs,
Returns: ) -> None:
A tuple of tensors: """
base_indices: Tensor of shape [batch_size] mapping batch rows to Performs GEMM for multiple slices of lora_a.
LoRA indices. """
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(prompt_mapping,
dtype=torch.long,
device=device)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
else:
# If long_lora doesn't exist,append None
indices_len.append(None)
return ( raise NotImplementedError
base_indices,
sampler_indices, @abstractmethod
sampler_indices_padded, def add_expand(
embeddings_indices, self,
long_lora_indices, y: torch.Tensor,
indices_len, x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
) lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
**kwargs,
) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
"""
raise NotImplementedError
@abstractmethod
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
and this layer only requires the expand operation.
"""
raise NotImplementedError
class PunicaWrapper: @abstractmethod
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
""" """
PunicaWrapper is designed to manage and provide metadata for the punica
raise NotImplementedError
@abstractmethod
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
"""
raise NotImplementedError
class PunicaWrapperBase(PunicaWrapperABC):
"""
PunicaWrapperBase is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel. Multi-LoRA, and to provide the interface for the punica.
""" """
def __init__(self, max_num_batched_tokens: int, max_batches: int, def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str]): device: Union[torch.device, str], **kwargs):
self._token_lora_indices = torch.empty(max_num_batched_tokens, self._token_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long, dtype=torch.long,
device=device) device=device)
...@@ -223,26 +167,6 @@ class PunicaWrapper: ...@@ -223,26 +167,6 @@ class PunicaWrapper:
self.is_prefill = False self.is_prefill = False
self.no_lora = False self.no_lora = False
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metada(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
def _update_base_metadata( def _update_base_metadata(
self, self,
mapping: "LoRAMapping", mapping: "LoRAMapping",
...@@ -298,6 +222,38 @@ class PunicaWrapper: ...@@ -298,6 +222,38 @@ class PunicaWrapper:
self.token_nums = token_nums self.token_nums = token_nums
self.no_lora = no_lora self.no_lora = no_lora
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
@property @property
def prefill_metadata( def prefill_metadata(
self self
...@@ -362,179 +318,32 @@ class PunicaWrapper: ...@@ -362,179 +318,32 @@ class PunicaWrapper:
long_lora_len = self.indices_len[4] long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len] return self._long_lora_indices[:long_lora_len]
def _shrink_prefill( def update_metadata(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_input,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_input,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input)
def _apply_expand(self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
def _apply_bias(
self, self,
indices: torch.Tensor, mapping: "LoRAMapping",
output: torch.Tensor, lora_index_to_id: List[Optional[int]],
output_slices: Tuple[int, ...], max_loras: int,
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], vocab_size: int,
): extra_vocab_size: int,
"""Applies bias to output long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs):
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
def _apply_shrink( self._update_base_metadata(mapping, lora_index_to_id, max_loras,
self, vocab_size, extra_vocab_size,
y: torch.Tensor, long_lora_context)
x: torch.Tensor, if mapping.is_prefill:
w_t_all: torch.Tensor, # Update metadata required for prefill-related operators.
scale: float, self._update_prefill_metada(self.token_lora_indices)
): self.is_prefill = True
""" else:
Perform the ` y+=x@w_t_all` computation, which is suitable for the self.is_prefill = False
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink( @abstractmethod
self, def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
x: torch.Tensor, scale: float, **kwargs) -> None:
lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float,
):
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics: Semantics:
for i in range(len(lora_a_stacked)): for i in range(len(lora_a_stacked)):
...@@ -545,16 +354,13 @@ class PunicaWrapper: ...@@ -545,16 +354,13 @@ class PunicaWrapper:
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1]) """
# TODO fuse these kernels # TODO: implement it based on torch ops
for slice_idx in range(len(lora_a_stacked)): raise NotImplementedError
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
def add_expand( @abstractmethod
self, def add_expand(self,
y: torch.Tensor, y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: Tuple[torch.Tensor, ...],
...@@ -562,7 +368,7 @@ class PunicaWrapper: ...@@ -562,7 +368,7 @@ class PunicaWrapper:
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_input=True, add_input=True,
) -> None: **kwargs) -> None:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM and bias addition for multiple slices of lora_b.
...@@ -581,35 +387,21 @@ class PunicaWrapper: ...@@ -581,35 +387,21 @@ class PunicaWrapper:
bias's weight bias's weight
output_slices (Tuple[int, ...]): Every slice's size output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True. add_input (bool): Defaults to True.
""" """
y_org = y # TODO: implement it based on torch ops
y = y.view(-1, y.shape[-1]) raise NotImplementedError
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_input=add_input,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding( @abstractmethod
self, def add_lora_embedding(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: torch.Tensor, lora_b_stacked: torch.Tensor,
add_input: bool = True, add_input: bool = True,
): **kwargs) -> None:
""" """
Applies lora specifically for VocabParallelEmbeddingWithLoRA. Applies lora specifically for VocabParallelEmbeddingWithLoRA.
and this layer only requires the expand operation.
Semantics: Semantics:
y += x @ lora_b_stacked y += x @ lora_b_stacked
...@@ -618,16 +410,12 @@ class PunicaWrapper: ...@@ -618,16 +410,12 @@ class PunicaWrapper:
x (torch.Tensor): Input tensor. x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights. lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True. add_input (bool): Default to True.
""" """
# TODO: implement it based on torch ops
raise NotImplementedError
# Embedding layer only need expand op @abstractmethod
expand_fun: Callable = (self._expand_prefill def add_lora_linear(self,
if self.is_prefill else self._expand_decode)
expand_fun(y, x, lora_b_stacked, add_input)
def add_lora_linear(
self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: Tuple[torch.Tensor, ...],
...@@ -636,7 +424,8 @@ class PunicaWrapper: ...@@ -636,7 +424,8 @@ class PunicaWrapper:
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
*, *,
buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
""" """
Applicable to linear-related lora. Applicable to linear-related lora.
...@@ -659,29 +448,10 @@ class PunicaWrapper: ...@@ -659,29 +448,10 @@ class PunicaWrapper:
output_slices (Tuple[int, ...]): Every slice's size. output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
""" """
# TODO: implement it based on torch ops
raise NotImplementedError
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) @abstractmethod
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
self.add_shrink(buffer, x, lora_a_stacked, scale)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_input=True)
def add_lora_logits(self, def add_lora_logits(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
...@@ -689,7 +459,8 @@ class PunicaWrapper: ...@@ -689,7 +459,8 @@ class PunicaWrapper:
lora_b_stacked: torch.Tensor, lora_b_stacked: torch.Tensor,
scale, scale,
*, *,
buffer: Optional[torch.Tensor] = None) -> None: buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
""" """
Applies lora specifically for LogitsProcessorWithLoRA. Applies lora specifically for LogitsProcessorWithLoRA.
...@@ -705,21 +476,5 @@ class PunicaWrapper: ...@@ -705,21 +476,5 @@ class PunicaWrapper:
scale (float): Scaling factor. scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None. buffer (Optional[torch.Tensor]):Default to None.
""" """
y_org = y # TODO: implement it based on torch ops
y = y.view(-1, y.shape[-1]) raise NotImplementedError
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Callable, Optional, Tuple, Union, final
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from .punica_base import PunicaWrapperBase
@final
class PunicaWrapperGPU(PunicaWrapperBase):
"""
PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica triton kernel.
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_input,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_input,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, scale: float):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
**kwargs) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_input=add_input,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
**kwargs) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
"""
# Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill
if self.is_prefill else self._expand_decode)
expand_fun(y, x, lora_b_stacked, add_input)
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_input=True,
**kwargs)
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)
from vllm.platforms import current_platform
from vllm.utils import print_info_once
from .punica_base import PunicaWrapperBase
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
if current_platform.is_cuda_alike():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
print_info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)
else:
raise NotImplementedError
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
token_lora_tensor, return_counts=True)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, token_nums, no_lora)
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(prompt_mapping,
dtype=torch.long,
device=device)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
else:
# If long_lora doesn't exist,append None
indices_len.append(None)
return (
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_indices,
indices_len,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment