Unverified Commit ca871491 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc][LoRA] Abstract PunicaWrapper (#10955)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 3b61cb45
......@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
# yapf: enable
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights)
from vllm.lora.punica import PunicaWrapper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -48,11 +48,12 @@ TOLERANCES = {
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
CUDA_DEVICES = [
# TODO: Modify this based on platform
DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# We will launch different triton kernels between the prefill and decode
#For GPU, we will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]
......@@ -192,9 +193,18 @@ def create_random_inputs(
return inputs, index_mapping, prompt_mapping
def check_punica_wrapper(punica_wrapper) -> bool:
if current_platform.is_cuda_alike():
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
return type(punica_wrapper) is PunicaWrapperGPU
else:
return False
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
......@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
......@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
......@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
......@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
......@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
......@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
......@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
......@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
......@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
......@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
......@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
......@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seed = 0
current_platform.seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
......
......@@ -17,7 +17,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.distributed.utils import divide
from vllm.lora.punica import PunicaWrapper
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
......@@ -33,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
if TYPE_CHECKING:
pass
from vllm.lora.punica_wrapper import PunicaWrapperBase
def _get_lora_device(base_layer: nn.Module) -> torch.device:
......@@ -115,9 +114,9 @@ class BaseLayerWithLoRA(nn.Module):
def set_mapping(
self,
punica_wrapper: PunicaWrapper,
punica_wrapper,
):
self.punica_wrapper: PunicaWrapper = punica_wrapper
self.punica_wrapper: PunicaWrapperBase = punica_wrapper
@classmethod
def can_replace_layer(
......
......@@ -21,7 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
......@@ -331,9 +331,9 @@ class LoRAModelManager(AdapterModelManager):
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device)
self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
......
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
__all__ = [
"PunicaWrapperBase",
"get_punica_wrapper",
]
......@@ -5,19 +5,12 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from .utils import compute_meta, convert_mapping
if TYPE_CHECKING:
# avoid circuit import
......@@ -25,166 +18,117 @@ if TYPE_CHECKING:
from vllm.lora.models import LongContextLoRAContext
def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
class PunicaWrapperABC(ABC):
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
PunicaWrapper ABC.
"""
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
token_lora_tensor, return_counts=True)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, token_nums, no_lora)
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(prompt_mapping,
dtype=torch.long,
device=device)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
else:
# If long_lora doesn't exist,append None
indices_len.append(None)
return (
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_indices,
indices_len,
)
class PunicaWrapper:
@abstractmethod
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs,
) -> None:
"""
Update the lora-related metadata
"""
raise NotImplementedError
@abstractmethod
def add_shrink(
self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> None:
"""
Performs GEMM for multiple slices of lora_a.
"""
raise NotImplementedError
@abstractmethod
def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
**kwargs,
) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
"""
raise NotImplementedError
@abstractmethod
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
and this layer only requires the expand operation.
"""
raise NotImplementedError
@abstractmethod
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
"""
raise NotImplementedError
@abstractmethod
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
"""
raise NotImplementedError
class PunicaWrapperBase(PunicaWrapperABC):
"""
PunicaWrapper is designed to manage and provide metadata for the punica
PunicaWrapperBase is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel.
Multi-LoRA, and to provide the interface for the punica.
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str]):
device: Union[torch.device, str], **kwargs):
self._token_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
......@@ -223,26 +167,6 @@ class PunicaWrapper:
self.is_prefill = False
self.no_lora = False
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metada(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
def _update_base_metadata(
self,
mapping: "LoRAMapping",
......@@ -298,6 +222,38 @@ class PunicaWrapper:
self.token_nums = token_nums
self.no_lora = no_lora
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
@property
def prefill_metadata(
self
......@@ -362,180 +318,33 @@ class PunicaWrapper:
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_input,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_input,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input)
def _apply_expand(self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs):
def _apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metada(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
def add_shrink(
self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float,
):
@abstractmethod
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs) -> None:
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
......@@ -545,24 +354,21 @@ class PunicaWrapper:
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
) -> None:
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
**kwargs) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
......@@ -581,35 +387,21 @@ class PunicaWrapper:
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_input=add_input,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
):
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
**kwargs) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
and this layer only requires the expand operation.
Semantics:
y += x @ lora_b_stacked
......@@ -618,25 +410,22 @@ class PunicaWrapper:
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
# Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill
if self.is_prefill else self._expand_decode)
expand_fun(y, x, lora_b_stacked, add_input)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
@abstractmethod
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
......@@ -659,29 +448,10 @@ class PunicaWrapper:
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
self.add_shrink(buffer, x, lora_a_stacked, scale)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_input=True)
@abstractmethod
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -689,7 +459,8 @@ class PunicaWrapper:
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None) -> None:
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
......@@ -704,22 +475,6 @@ class PunicaWrapper:
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)
"""
# TODO: implement it based on torch ops
raise NotImplementedError
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Callable, Optional, Tuple, Union, final
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from .punica_base import PunicaWrapperBase
@final
class PunicaWrapperGPU(PunicaWrapperBase):
"""
PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica triton kernel.
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_input,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_input,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, scale: float):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
**kwargs) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_input=add_input,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
**kwargs) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
"""
# Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill
if self.is_prefill else self._expand_decode)
expand_fun(y, x, lora_b_stacked, add_input)
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_input=True,
**kwargs)
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)
from vllm.platforms import current_platform
from vllm.utils import print_info_once
from .punica_base import PunicaWrapperBase
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
if current_platform.is_cuda_alike():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
print_info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)
else:
raise NotImplementedError
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
token_lora_tensor, return_counts=True)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, token_nums, no_lora)
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(prompt_mapping,
dtype=torch.long,
device=device)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
else:
# If long_lora doesn't exist,append None
indices_len.append(None)
return (
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_indices,
indices_len,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment