Unverified commit 571da8fc, authored by Jee Jee Li, committed by GitHub
Browse files

[Misc][LoRA] Clean up the function interface of Punica (#10917)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 39c89e71
......@@ -565,7 +565,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
......@@ -573,7 +575,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)
def create_random_linear_replicated_layer():
......@@ -585,7 +588,12 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
lora_linear = ReplicatedLinearWithLoRA(linear)
lora_linear.create_lora_weights(max_loras, lora_config)
assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == 1)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
......@@ -669,8 +677,9 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage) -> None:
device, stage, bias_enabled) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
......@@ -679,7 +688,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)
def create_random_linear_parallel_layer():
if orientation == "row":
......@@ -700,7 +710,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
if not fully_shard else
ColumnParallelLinearWithShardedLoRA(linear))
lora_linear.create_lora_weights(max_loras, lora_config)
assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == 1)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
......@@ -784,8 +799,9 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage) -> None:
device, stage, bias_enabled) -> None:
torch.cuda.set_device(device)
torch.set_default_device(device)
......@@ -794,7 +810,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)
def create_column_parallel_packed_layer():
if repeats == 2:
......@@ -832,10 +849,16 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
num_key_value_heads = 32
num_attention_heads = 32
n_slices = repeats
lora_linear.create_lora_weights(max_loras,
lora_config,
model_config=FakeConfig())
assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == n_slices)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
......@@ -911,7 +934,6 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
512,
lora_config.lora_extra_vocab_size,
)
# lora_linear.set_mapping(*mapping_info)
lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0]
......
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
import torch
import torch.nn as nn
......@@ -32,6 +32,44 @@ def _fully_sharded_can_replace(can_replace):
return dec
def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
    """
    Shared `apply` implementation for `ColumnParallelLinearWithLoRA` and
    the classes derived from it.

    Steps, in order:
      1. run the base layer's quantization method on `x` (with `bias`);
      2. shrink `x` through every `lora_a` slice into a float32 buffer;
      3. all-gather that buffer across the tensor-parallel group;
      4. expand the gathered buffer through every `lora_b` slice (plus the
         optional LoRA bias) on top of the base output.

    Args:
        x: Input activations; flattened to 2D on the last dimension.
        bias: Bias forwarded unchanged to the base layer's quant method.
        layer: LoRA layer supplying `base_layer`, `punica_wrapper`,
            `n_slices`, `output_slices` and the stacked LoRA weights.

    Returns:
        The base output with the LoRA contribution added, viewed back to
        the shape the base layer produced.

    Raises:
        AssertionError: if the number of lora_a / lora_b / bias / output
            slices does not equal `layer.n_slices`.
    """
    num_slices = layer.n_slices
    assert (num_slices == len(layer.lora_a_stacked) == len(
        layer.lora_b_stacked) == len(layer.output_slices))
    if layer.lora_bias_stacked is not None:
        assert num_slices == len(layer.lora_bias_stacked)

    base_out = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
    # Remember the base layer's output shape before flattening so it can be
    # restored on return.
    full_shape = base_out.shape
    flat_x = x.view(-1, x.shape[-1])
    flat_out = base_out.view(-1, base_out.shape[-1])

    # The shrink results have to go through a collective, so they live in
    # one stacked tensor (one leading index per slice) instead of a tuple of
    # per-slice tensors.
    # NOTE(review): the last dim is lora_a_stacked[0].shape[2], presumably
    # this rank's share of the LoRA rank -- confirm against slice_lora_a.
    shrunk = torch.zeros(
        (num_slices, flat_x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=flat_x.device,
    )
    layer.punica_wrapper.add_shrink(shrunk, flat_x, layer.lora_a_stacked, 1.0)
    gathered = tensor_model_parallel_all_gather(shrunk)

    # add_expand's return value is not used; the result is read back from
    # `flat_out`, so the call is relied on to update it in place.
    layer.punica_wrapper.add_expand(
        flat_out,
        gathered,
        layer.lora_b_stacked,
        layer.lora_bias_stacked,
        layer.output_slices,
        offset_start=0,
        add_input=True,
    )

    # Column-partitioned (and, for merged layers, packed) output.
    return flat_out.view(*full_shape)
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
......@@ -51,34 +89,15 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device,
)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
......@@ -99,46 +118,6 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
)
def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
"""
MergedColumnParallelLinearWithShardedLoRA and
MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method.
The main difference is the step by shard_size for lora_b which can
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros(
(n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
for idx in range(n):
layer.punica_wrapper.add_shrink(buffers[idx], x,
layer.lora_a_stacked[idx], 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
layer.punica_wrapper.add_expand_packed_nslice(
output,
buffers,
layer.lora_b_stacked,
layer.bias_stacked,
1.0,
layer.output_slices,
)
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output
class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA):
"""
......@@ -162,8 +141,9 @@ class MergedColumnParallelLinearWithShardedLoRA(
]
return lora_a
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
......@@ -195,31 +175,15 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
......@@ -260,8 +224,9 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
]
return lora_a
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
......@@ -294,7 +259,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked.shape[2]
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
......@@ -303,20 +268,24 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
shard_size = self.bias_stacked.shape[2]
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
shard_size = self.lora_bias_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias
def apply(self, x: torch.Tensor) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
......@@ -330,12 +299,18 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked,
self.bias_stacked, start_idx,
shard_size)
# NOTE offset are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
self.lora_bias_stacked,
self.output_slices,
offset_start=offset_start,
add_input=True,
)
output = output.view(*out_orig_shape)
return output
......
This diff is collapsed.
......@@ -555,17 +555,17 @@ class LoRAModelManager(AdapterModelManager):
input_dim,
output_dim,
rank,
module.lora_a_stacked.dtype,
module.lora_a_stacked[0].dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim,
bias_enabled=bias_enabled)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.lora_a_stacked.shape[-1],
module.lora_b_stacked.shape[-2],
module.lora_a_stacked[0].shape[-1],
module.lora_b_stacked[0].shape[-2],
rank,
module.lora_a_stacked.dtype,
module.lora_a_stacked[0].dtype,
"cpu",
bias_enabled=bias_enabled,
)
......
......@@ -362,7 +362,7 @@ class PunicaWrapper:
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
def shrink_prefill(
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -380,7 +380,7 @@ class PunicaWrapper:
scale,
)
def shrink_decode(
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -389,7 +389,7 @@ class PunicaWrapper:
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def expand_prefill(
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -407,7 +407,7 @@ class PunicaWrapper:
add_input,
)
def expand_decode(
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -416,7 +416,7 @@ class PunicaWrapper:
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
def expand_slice_prefill(
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -438,7 +438,7 @@ class PunicaWrapper:
add_input,
)
def expand_slice_decode(
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -450,41 +450,35 @@ class PunicaWrapper:
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input)
def apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
bias_stacked: torch.Tensor,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
def _apply_expand(self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
bias_stacked = bias_stacked[indices]
bias_stacked[indices == -1] = 0
output += bias_stacked
return output.view_as(org_output)
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
def apply_bias_packed_nslice(
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
bias_stacked: Tuple[Optional[torch.Tensor], ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
......@@ -496,7 +490,7 @@ class PunicaWrapper:
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = bias_stacked[slice_idx]
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
......@@ -506,7 +500,7 @@ class PunicaWrapper:
return output.view_as(org_output)
def add_shrink(
def _apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
......@@ -517,188 +511,215 @@ class PunicaWrapper:
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the shrink_decode function
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
shrink_fun: Callable = (self.shrink_prefill
if self.is_prefill else self.shrink_decode)
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_expand(
def add_shrink(
self,
y: torch.Tensor,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
w_t_all: torch.Tensor,
bias_all: Optional[torch.Tensor],
add_input: bool = True,
lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float,
):
"""
Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the
GEMM of lora'b.
When `is_prefill` is true, it indicates that it is currently the
prefill stage, and the `expand_prefill` function should be called.
Otherwise, it is the decode stage, and the expand_decode function
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
if bias_all is not None:
y = self.apply_bias(self.token_lora_indices, y, bias_all)
expand_fun: Callable = (self.expand_prefill
if self.is_prefill else self.expand_decode)
expand_fun(y, x, w_t_all, add_input)
def add_expand_slice(self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
bias_all: Optional[torch.Tensor],
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True):
"""
Similar to `add_expand`
"""
if bias_all is not None:
y = self.apply_bias(self.token_lora_indices, y, bias_all)
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
expand_slice_fun: Callable = (self.expand_slice_prefill
if self.is_prefill else
self.expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
lora_b_stacked: Tuple[torch.Tensor, ...],
bias_stacked: Optional[Tuple[torch.Tensor,
...]],
scale: float,
output_slices: Tuple[int, ...]) -> None:
"""
Similar to `add_expand`
def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = 0
if bias_stacked is not None:
self.apply_bias_packed_nslice(self.token_lora_indices, y,
output_slices, bias_stacked)
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self.add_expand_slice(y,
x[slice_idx],
lora_b_stacked[slice_idx],
None,
offset_left,
output_slices[slice_idx],
add_input=True)
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_input=add_input,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora(self,
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
bias_all: Optional[torch.Tensor],
scale: float,
y_offset: Optional[int] = None,
y_slice_size: Optional[int] = None,
*,
buffer: Optional[torch.Tensor] = None) -> None:
def add_lora_embedding(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    add_input: bool = True,
):
    """
    Apply LoRA for VocabParallelEmbeddingWithLoRA.

    Semantics:
        y += x @ lora_b_stacked

    Args:
        y (torch.Tensor): Output tensor.
        x (torch.Tensor): Input tensor.
        lora_b_stacked (torch.Tensor): lora_b's weights.
        add_input (bool): Forwarded to the expand kernel wrapper.
            Defaults to True.
    """
    # The embedding path needs only the expand op; which wrapper runs
    # depends on whether this is the prefill or the decode stage.
    if self.is_prefill:
        self._expand_prefill(y, x, lora_b_stacked, add_input)
    else:
        self._expand_decode(y, x, lora_b_stacked, add_input)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
"""
Applicable to linear-related lora.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)+bias[i]
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
wa_t_all (torch.Tensor): lora_a's weight
wb_t_all (torch.Tensor): lora_b's weight
bias_all: (torch.Tensor): lora's bias
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = wb_t_all.size(-1)
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
if bias_all is not None:
y = self.apply_bias(self.token_lora_indices, y, bias_all)
self.add_shrink(buffer, x, wa_t_all, scale)
if y_offset is None and y_slice_size is None:
self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True)
else:
self.add_expand_slice(y,
buffer,
wb_t_all,
None,
y_offset,
y_slice_size,
add_input=True)
y = y.view_as(y_org)
def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
bias_all: Tuple[Optional[torch.Tensor],
...], scale: float,
output_slices: Tuple[int, ...]) -> None:
"""
Applies lora to each input. Similar to add_lora, This method is
used for layers that are composed of multiple sublayers
(slices) packed together.
"""
y_org = y
x = x.view(-1, x.shape[-1])
y = y.view(-1, y.shape[-1])
offset_left = 0
if bias_all is not None:
y = self.apply_bias_packed_nslice(self.token_lora_indices, y,
output_slices, bias_all)
# TODO fuse these kernels
for slice_idx in range(len(output_slices)):
self.add_lora(y, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], None, scale, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
self.add_shrink(buffer, x, lora_a_stacked, scale)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_input=True)
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None) -> None:
"""
LogitsProcessorWithLoRA always using bgmv
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = wb_t_all.size(-1)
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment