"vllm/vscode:/vscode.git/clone" did not exist on "de92d916fe8a897b00a8adb0aab9ed9ec99f2b6c"
Unverified Commit ab196ede authored by Ashwin Phadke's avatar Ashwin Phadke Committed by GitHub
Browse files

Remove LoRA bias support (#25807)


Signed-off-by: default avatarAshwin Phadke <ashwinphadke12@rediffmail.com>
Signed-off-by: default avatarAshwin Phadke <23502062+ashwin-phadke@users.noreply.github.com>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 3ee202ea
...@@ -23,11 +23,6 @@ BADREQUEST_CASES = [ ...@@ -23,11 +23,6 @@ BADREQUEST_CASES = [
{"r": 1024}, {"r": 1024},
"is greater than max_lora_rank", "is greater than max_lora_rank",
), ),
(
"test_bias",
{"bias": "all"},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {"use_dora": True}, "does not yet support DoRA"), ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
( (
"test_modules_to_save", "test_modules_to_save",
......
...@@ -16,11 +16,6 @@ ERROR_CASES = [ ...@@ -16,11 +16,6 @@ ERROR_CASES = [
{"r": 1024}, {"r": 1024},
"is greater than max_lora_rank", "is greater than max_lora_rank",
), ),
(
"test_bias",
{"bias": "all"},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {"use_dora": True}, "does not yet support DoRA"), ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
( (
"test_modules_to_save", "test_modules_to_save",
......
...@@ -21,7 +21,6 @@ class LoRANameParserTestConfig(NamedTuple): ...@@ -21,7 +21,6 @@ class LoRANameParserTestConfig(NamedTuple):
name: str name: str
module_name: str module_name: str
is_lora_a: bool is_lora_a: bool
is_bias: bool
weights_mapper: Optional[WeightsMapper] = None weights_mapper: Optional[WeightsMapper] = None
...@@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid(): ...@@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid():
"base_model.model.model.embed_tokens.lora_embedding_A", "base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens", "model.embed_tokens",
True, True,
False,
), ),
LoRANameParserTestConfig( LoRANameParserTestConfig(
"base_model.model.model.embed_tokens.lora_embedding_B", "base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens", "model.embed_tokens",
False, False,
False,
), ),
LoRANameParserTestConfig( LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj", "model.layers.9.mlp.down_proj",
True, True,
False,
), ),
LoRANameParserTestConfig( LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj", "model.layers.9.mlp.down_proj",
False, False,
False,
), ),
LoRANameParserTestConfig( LoRANameParserTestConfig(
"language_model.layers.9.mlp.down_proj.lora_A.weight", "language_model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.layers.9.mlp.down_proj", "language_model.layers.9.mlp.down_proj",
True, True,
False,
), ),
LoRANameParserTestConfig( LoRANameParserTestConfig(
"language_model.layers.9.mlp.down_proj.lora_B.weight", "language_model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.layers.9.mlp.down_proj", "language_model.layers.9.mlp.down_proj",
False, False,
False,
), ),
# Test with WeightsMapper # Test with WeightsMapper
LoRANameParserTestConfig( LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.model.layers.9.mlp.down_proj", "language_model.model.layers.9.mlp.down_proj",
True, True,
False,
weights_mapper=WeightsMapper( weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."} orig_to_new_prefix={"model.": "language_model.model."}
), ),
...@@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid(): ...@@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid():
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.model.layers.9.mlp.down_proj", "language_model.model.layers.9.mlp.down_proj",
False, False,
False,
weights_mapper=WeightsMapper( weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."} orig_to_new_prefix={"model.": "language_model.model."}
), ),
...@@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid(): ...@@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid():
"model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.model.layers.9.mlp.down_proj", "language_model.model.layers.9.mlp.down_proj",
True, True,
False,
weights_mapper=WeightsMapper( weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."} orig_to_new_prefix={"model.": "language_model.model."}
), ),
...@@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid(): ...@@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid():
"model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.model.layers.9.mlp.down_proj", "language_model.model.layers.9.mlp.down_proj",
False, False,
False,
weights_mapper=WeightsMapper( weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."} orig_to_new_prefix={"model.": "language_model.model."}
), ),
), ),
] ]
for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: for name, module_name, is_lora_a, weights_mapper in fixture:
assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name( assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(
name, weights_mapper name, weights_mapper
) )
......
...@@ -70,12 +70,6 @@ class LoRAConfig: ...@@ -70,12 +70,6 @@ class LoRAConfig:
per prompt. When run in offline mode, the lora IDs for n modalities per prompt. When run in offline mode, the lora IDs for n modalities
will be automatically assigned to 1-n with the names of the modalities will be automatically assigned to 1-n with the names of the modalities
in alphabetic order.""" in alphabetic order."""
bias_enabled: bool = Field(
default=False,
deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
)
"""[DEPRECATED] Enable bias for LoRA adapters. This option will be
removed in v0.12.0."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
...@@ -96,7 +90,7 @@ class LoRAConfig: ...@@ -96,7 +90,7 @@ class LoRAConfig:
factors.append(self.lora_dtype) factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size) factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size) factors.append(self.lora_vocab_padding_size)
factors.append(self.bias_enabled)
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
......
...@@ -439,7 +439,6 @@ class EngineArgs: ...@@ -439,7 +439,6 @@ class EngineArgs:
video_pruning_rate: float = MultiModalConfig.video_pruning_rate video_pruning_rate: float = MultiModalConfig.video_pruning_rate
# LoRA fields # LoRA fields
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank max_lora_rank: int = LoRAConfig.max_lora_rank
default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
...@@ -916,7 +915,6 @@ class EngineArgs: ...@@ -916,7 +915,6 @@ class EngineArgs:
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
help="If True, enable handling of LoRA adapters.", help="If True, enable handling of LoRA adapters.",
) )
lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
lora_group.add_argument( lora_group.add_argument(
...@@ -1515,7 +1513,6 @@ class EngineArgs: ...@@ -1515,7 +1513,6 @@ class EngineArgs:
lora_config = ( lora_config = (
LoRAConfig( LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras, max_loras=self.max_loras,
default_mm_loras=self.default_mm_loras, default_mm_loras=self.default_mm_loras,
......
...@@ -45,7 +45,6 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -45,7 +45,6 @@ class BaseLayerWithLoRA(nn.Module):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
"""Overwrites lora tensors at index.""" """Overwrites lora tensors at index."""
... ...
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, cast from typing import Optional
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -29,7 +29,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -29,7 +29,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.tp_size = self.base_layer.tp_size self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer) self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: tuple[int, ...] self.output_slices: tuple[int, ...]
self.output_size: int self.output_size: int
self.n_slices: int self.n_slices: int
...@@ -86,30 +85,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -86,30 +85,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
) )
for _ in range(self.n_slices) for _ in range(self.n_slices)
) )
if lora_config.bias_enabled:
lora_bias_out_size = lora_b_out_size
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_bias_out_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.output_slices = (self.lora_b_stacked[0].shape[2],) self.output_slices = (self.lora_b_stacked[0].shape[2],)
def reset_lora(self, index: int): def reset_lora(self, index: int):
for s_index in range(self.n_slices): for s_index in range(self.n_slices):
self.lora_a_stacked[s_index][index] = 0 self.lora_a_stacked[s_index][index] = 0
self.lora_b_stacked[s_index][index] = 0 self.lora_b_stacked[s_index][index] = 0
if self.lora_config.bias_enabled:
# Make mypy happy
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
self.lora_bias_stacked[s_index][index] = 0
def set_lora( def set_lora(
self, self,
...@@ -117,7 +98,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -117,7 +98,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
): ):
# Except for QKVParallelLinearWithLoRA and # Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
...@@ -131,8 +111,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -131,8 +111,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
if self.tp_size > 1: if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)
self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True lora_a, non_blocking=True
...@@ -140,14 +118,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -140,14 +118,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True lora_b, non_blocking=True
) )
if lora_bias is not None:
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_(
lora_bias, non_blocking=True
)
def apply( def apply(
self, x: torch.Tensor, bias: Optional[torch.Tensor] = None self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
...@@ -162,13 +132,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -162,13 +132,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
x = x.flatten(0, 1) x = x.flatten(0, 1)
lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear( lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear(
output, output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.lora_bias_stacked,
1.0,
self.output_slices,
) )
if not current_platform.can_update_inplace(): if not current_platform.can_update_inplace():
output = lora_output output = lora_output
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union, cast from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): ...@@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
== len(layer.lora_b_stacked) == len(layer.lora_b_stacked)
== len(layer.output_slices) == len(layer.output_slices)
) )
if layer.lora_bias_stacked is not None:
assert layer.n_slices == len(layer.lora_bias_stacked)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
...@@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): ...@@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
output, output,
buffers, buffers,
layer.lora_b_stacked, layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices, layer.output_slices,
offset_start=0, offset_start=0,
add_input=True, add_input=True,
...@@ -122,16 +119,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -122,16 +119,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
lora_b = lora_b[start_idx:end_idx, :] lora_b = lora_b[start_idx:end_idx, :]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
# TODO: Fix the slicing logic of bias.
if bias is None:
return bias
shard_size = self.output_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias
def forward( def forward(
self, input_: torch.Tensor self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
...@@ -238,17 +225,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -238,17 +225,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) )
for output_size in self.output_slices for output_size in self.output_slices
) )
if lora_config.bias_enabled:
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
output_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for output_size in self.output_slices
)
def slice_lora_a( def slice_lora_a(
self, lora_a: list[Union[torch.Tensor, None]] self, lora_a: list[Union[torch.Tensor, None]]
...@@ -268,31 +244,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -268,31 +244,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
] ]
return sliced_lora_b return sliced_lora_b
def slice_bias(
self, bias: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)
):
if (bias_i := bias[i]) is not None:
bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)]
return bias
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)
for i in range(self.n_slices): for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None: if (lora_a_i := lora_a[i]) is not None:
...@@ -304,16 +267,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -304,16 +267,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1] index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
].copy_(lora_b_i, non_blocking=True) ].copy_(lora_b_i, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
for i in range(self.n_slices):
if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_(
lora_bias_i, non_blocking=True
)
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
...@@ -380,24 +333,6 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -380,24 +333,6 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
bias_q = bias[
self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
* (self.q_shard_id + 1)
]
k_offset = self.q_proj_total_size
bias_k = bias[
k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)
]
v_offset = k_offset + self.kv_proj_total_size
bias_v = bias[
v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)
]
bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
return bias
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
......
...@@ -143,7 +143,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -143,7 +143,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union, cast from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -39,9 +39,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -39,9 +39,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias
def forward( def forward(
self, input_: torch.Tensor self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
...@@ -123,16 +120,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -123,16 +120,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b = lora_b[start_idx:end_idx, :] lora_b = lora_b[start_idx:end_idx, :]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked)
shard_size = self.lora_bias_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias
def apply( def apply(
self, x: torch.Tensor, bias: Optional[torch.Tensor] = None self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -167,7 +154,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -167,7 +154,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
output, output,
buffer, buffer,
self.lora_b_stacked, self.lora_b_stacked,
self.lora_bias_stacked,
self.output_slices, self.output_slices,
offset_start=offset_start, offset_start=offset_start,
add_input=True, add_input=True,
......
...@@ -91,7 +91,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -91,7 +91,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major, # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
......
...@@ -21,7 +21,6 @@ class LoRALayerWeights: ...@@ -21,7 +21,6 @@ class LoRALayerWeights:
lora_alpha: int, lora_alpha: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
embeddings_tensor: Optional[torch.Tensor] = None, embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None, scaling: Optional[float] = None,
) -> None: ) -> None:
...@@ -30,7 +29,6 @@ class LoRALayerWeights: ...@@ -30,7 +29,6 @@ class LoRALayerWeights:
self.lora_alpha = lora_alpha self.lora_alpha = lora_alpha
self.lora_a = lora_a self.lora_a = lora_a
self.lora_b = lora_b self.lora_b = lora_b
self.bias = bias
self.embeddings_tensor = embeddings_tensor self.embeddings_tensor = embeddings_tensor
if scaling is None: if scaling is None:
...@@ -71,13 +69,13 @@ class LoRALayerWeights: ...@@ -71,13 +69,13 @@ class LoRALayerWeights:
peft_helper: PEFTHelper, peft_helper: PEFTHelper,
embeddings_tensor: Optional[torch.Tensor] = None, embeddings_tensor: Optional[torch.Tensor] = None,
) -> "LoRALayerWeights": ) -> "LoRALayerWeights":
# lora_a and lora_b are set to None for config-based construction
return cls( return cls(
module_name, module_name,
peft_helper.r, peft_helper.r,
peft_helper.lora_alpha, peft_helper.lora_alpha,
None, None,
None, None,
None,
embeddings_tensor, embeddings_tensor,
peft_helper.vllm_lora_scaling_factor, peft_helper.vllm_lora_scaling_factor,
) )
...@@ -92,7 +90,6 @@ class LoRALayerWeights: ...@@ -92,7 +90,6 @@ class LoRALayerWeights:
dtype: torch.dtype, dtype: torch.dtype,
device: torch.types.Device, device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None, embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False,
) -> "LoRALayerWeights": ) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros( lora_a = torch.zeros(
...@@ -101,12 +98,6 @@ class LoRALayerWeights: ...@@ -101,12 +98,6 @@ class LoRALayerWeights:
lora_b = torch.zeros( lora_b = torch.zeros(
[output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
) )
if bias_enabled:
bias = torch.zeros(
[output_dim], dtype=dtype, device=device, pin_memory=pin_memory
)
else:
bias = None
embeddings_tensor = ( embeddings_tensor = (
torch.rand( torch.rand(
...@@ -125,7 +116,6 @@ class LoRALayerWeights: ...@@ -125,7 +116,6 @@ class LoRALayerWeights:
lora_alpha=1, lora_alpha=1,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
bias=bias,
embeddings_tensor=embeddings_tensor, embeddings_tensor=embeddings_tensor,
) )
...@@ -140,7 +130,6 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -140,7 +130,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alphas: list[Optional[int]], lora_alphas: list[Optional[int]],
lora_a: list[Optional[torch.Tensor]], lora_a: list[Optional[torch.Tensor]],
lora_b: list[Optional[torch.Tensor]], lora_b: list[Optional[torch.Tensor]],
bias: Optional[list[Optional[torch.Tensor]]] = None,
scaling: Optional[list[float]] = None, scaling: Optional[list[float]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -149,7 +138,6 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -149,7 +138,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0, lora_alpha=0,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
bias=bias,
scaling=scaling, # type: ignore scaling=scaling, # type: ignore
embeddings_tensor=None, embeddings_tensor=None,
) )
...@@ -181,7 +169,6 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -181,7 +169,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras],
[lora.bias if lora is not None else None for lora in loras],
scaling=[ scaling=[
1 if lora is not None else None # type: ignore 1 if lora is not None else None # type: ignore
for lora in loras for lora in loras
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import math import math
import os import os
from collections.abc import Sequence
from typing import Callable, Optional, TypeVar, Union from typing import Callable, Optional, TypeVar, Union
import regex as re import regex as re
...@@ -140,7 +139,7 @@ class LoRAModel: ...@@ -140,7 +139,7 @@ class LoRAModel:
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: dict[str, LoRALayerWeights] = {} loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items(): for tensor_name, tensor in tensors.items():
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( module_name, is_lora_a = parse_fine_tuned_lora_name(
tensor_name, weights_mapper tensor_name, weights_mapper
) )
if module_name not in loras: if module_name not in loras:
...@@ -160,13 +159,7 @@ class LoRAModel: ...@@ -160,13 +159,7 @@ class LoRAModel:
module_name, peft_helper, lora_embeddings_tensor module_name, peft_helper, lora_embeddings_tensor
) )
if is_bias: if is_lora_a:
loras[module_name].bias = tensor.to(device=device, dtype=dtype)
bias = tensor.to(device=device, dtype=dtype)
if pin_memory:
bias = bias.pin_memory()
loras[module_name].bias = bias
elif is_lora_a:
loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
if pin_memory: if pin_memory:
loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
...@@ -234,9 +227,7 @@ class LoRAModel: ...@@ -234,9 +227,7 @@ class LoRAModel:
def check_unexpected_modules(modules: dict): def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa for lora_module in modules.keys(): # noqa
module_name, _, _ = parse_fine_tuned_lora_name( module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
lora_module, weights_mapper
)
part_name = module_name.split(".")[-1] part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules: if part_name not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
...@@ -439,23 +430,11 @@ class LoRAModelManager: ...@@ -439,23 +430,11 @@ class LoRAModelManager:
module_lora = self._get_lora_layer_weights(lora_model, module_name) module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora: if module_lora:
module_lora.optimize() module_lora.optimize()
# Bias is not explicitly enabled with the flag enable_lora_bias.
bias = module_lora.bias
if (
torch.is_tensor(bias)
or (isinstance(bias, Sequence) and any(b is not None for b in bias))
) and not self.lora_config.bias_enabled:
module_lora.bias = None
raise ValueError(
f"Adapter bias cannot be used for {module_name}"
" without --enable-lora-bias."
)
module.set_lora( module.set_lora(
index, index,
module_lora.lora_a, module_lora.lora_a,
module_lora.lora_b, module_lora.lora_b,
module_lora.embeddings_tensor, module_lora.embeddings_tensor,
module_lora.bias,
) )
else: else:
module.reset_lora(index) module.reset_lora(index)
...@@ -581,7 +560,6 @@ class LoRAModelManager: ...@@ -581,7 +560,6 @@ class LoRAModelManager:
"""Create zero-initialized LoRAModel for warmup.""" """Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}) model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules():
bias_enabled = self.lora_config.bias_enabled
if ( if (
not self._match_target_modules(module_name) not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA) or not isinstance(module, BaseLayerWithLoRA)
...@@ -616,7 +594,6 @@ class LoRAModelManager: ...@@ -616,7 +594,6 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype, module.lora_a_stacked[0].dtype,
"cpu", "cpu",
embeddings_tensor_dim=embeddings_tensor_dim, embeddings_tensor_dim=embeddings_tensor_dim,
bias_enabled=bias_enabled,
) )
else: else:
lora = LoRALayerWeights.create_dummy_lora_weights( lora = LoRALayerWeights.create_dummy_lora_weights(
...@@ -626,7 +603,6 @@ class LoRAModelManager: ...@@ -626,7 +603,6 @@ class LoRAModelManager:
rank, rank,
module.lora_a_stacked[0].dtype, module.lora_a_stacked[0].dtype,
"cpu", "cpu",
bias_enabled=bias_enabled,
) )
else: else:
parts = module_name.split(".") parts = module_name.split(".")
...@@ -640,7 +616,6 @@ class LoRAModelManager: ...@@ -640,7 +616,6 @@ class LoRAModelManager:
rank, rank,
module.lora_a_stacked[i].dtype, module.lora_a_stacked[i].dtype,
"cpu", "cpu",
bias_enabled=bias_enabled,
) )
subloras.append(lora) subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras) lora = PackedLoRALayerWeights.pack(subloras)
......
...@@ -29,7 +29,7 @@ class PEFTHelper: ...@@ -29,7 +29,7 @@ class PEFTHelper:
lora_alpha: int lora_alpha: int
target_modules: Union[list[str], str] target_modules: Union[list[str], str]
bias: Literal["none", "all", "lora_only"] = field(default="none") bias: Literal["none"] = field(default="none")
modules_to_save: Optional[list[str]] = field(default=None) modules_to_save: Optional[list[str]] = field(default=None)
# True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732) # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
use_rslora: bool = field(default=False) use_rslora: bool = field(default=False)
...@@ -122,7 +122,7 @@ class PEFTHelper: ...@@ -122,7 +122,7 @@ class PEFTHelper:
f"LoRA rank {self.r} is greater than max_lora_rank" f"LoRA rank {self.r} is greater than max_lora_rank"
f" {lora_config.max_lora_rank}." f" {lora_config.max_lora_rank}."
) )
if self.bias != "none" and not lora_config.bias_enabled: if self.bias != "none":
error_msg.append("Adapter bias cannot be used without bias_enabled.") error_msg.append("Adapter bias is not supported.")
if error_msg: if error_msg:
raise ValueError(f"{' '.join(error_msg)}") raise ValueError(f"{' '.join(error_msg)}")
...@@ -60,14 +60,13 @@ class PunicaWrapperABC(ABC): ...@@ -60,14 +60,13 @@ class PunicaWrapperABC(ABC):
y: torch.Tensor, y: torch.Tensor,
x: Union[tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM for multiple slices of lora_b.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -93,7 +92,6 @@ class PunicaWrapperABC(ABC): ...@@ -93,7 +92,6 @@ class PunicaWrapperABC(ABC):
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
*, *,
...@@ -222,38 +220,6 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -222,38 +220,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
self.token_nums = token_nums self.token_nums = token_nums
self.no_lora = no_lora self.no_lora = no_lora
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: tuple[int, ...],
lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left : offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
@property @property
def prefill_metadata( def prefill_metadata(
self, self,
...@@ -365,29 +331,25 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -365,29 +331,25 @@ class PunicaWrapperBase(PunicaWrapperABC):
y: torch.Tensor, y: torch.Tensor,
x: Union[tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM for multiple slices of lora_b.
Semantics: Semantics:
offset = offset_start offset = offset_start
for i in range(len(lora_b_stacked)): for i in range(len(lora_b_stacked)):
slice = output_slices[i] slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
lora_bias_stacked[i]
offset += slice offset += slice
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
offset_start (int): The starting position of y, defaults to 0 offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
...@@ -427,7 +389,6 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -427,7 +389,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
*, *,
...@@ -444,14 +405,13 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -444,14 +405,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
@ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :]
* scale * scale
).squeeze(0)+lora_bias_stacked[i] ).squeeze(0)
Args: Args:
y (torch.Tensor): Output tensor. Will be changed in-place. y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
......
...@@ -199,38 +199,30 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -199,38 +199,30 @@ class PunicaWrapperCPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: Union[tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM for multiple slices of lora_b.
Semantics: Semantics:
for i in range(len(lora_b_stacked)): for i in range(len(lora_b_stacked)):
slice = output_slices[i] slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
lora_bias_stacked[i]
offset += slice offset += slice
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
y_org = y y_org = y
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
offset_left = offset_start offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(
self.token_lora_indices, y, output_slices, lora_bias_stacked
)
for slice_idx in range(len(lora_b_stacked)): for slice_idx in range(len(lora_b_stacked)):
self._apply_expand( self._apply_expand(
y, y,
...@@ -276,7 +268,6 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -276,7 +268,6 @@ class PunicaWrapperCPU(PunicaWrapperBase):
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
*, *,
...@@ -293,25 +284,19 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -293,25 +284,19 @@ class PunicaWrapperCPU(PunicaWrapperBase):
@ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :]
* scale * scale
).squeeze(0)+lora_bias_stacked[i] ).squeeze(0)
Args: Args:
y (torch.Tensor): Output tensor. Will be changed in-place. y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
""" """
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(
self.token_lora_indices, y, output_slices, lora_bias_stacked
)
if buffer is None: if buffer is None:
r = lora_b_stacked[0].size(-1) r = lora_b_stacked[0].size(-1)
...@@ -323,7 +308,7 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -323,7 +308,7 @@ class PunicaWrapperCPU(PunicaWrapperBase):
) )
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand( self.add_expand(
y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
) )
def add_lora_logits( def add_lora_logits(
......
...@@ -101,36 +101,29 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -101,36 +101,29 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM for multiple slices of lora_b.
Semantics: Semantics:
for i in range(len(lora_b_stacked)): for i in range(len(lora_b_stacked)):
slice = output_slices[i] slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
lora_bias_stacked[i]
offset += slice offset += slice
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors x (torch.Tensor): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
y_org = y y_org = y
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
if lora_bias_stacked is not None:
token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0))
self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked)
assert x.ndim == 3 assert x.ndim == 3
assert x.size(0) == len(output_slices) assert x.size(0) == len(output_slices)
...@@ -183,7 +176,6 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -183,7 +176,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
*, *,
...@@ -200,26 +192,18 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -200,26 +192,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
@ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :]
* scale * scale
).squeeze(0)+lora_bias_stacked[i] ).squeeze(0)
Args: Args:
y (torch.Tensor): Output tensor. Will be changed in-place. y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None. buffer (Optional[torch.Tensor]): Defaults to None.
""" """
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0))
y = self._apply_bias(
token_lora_indices, y, output_slices, lora_bias_stacked
)
if buffer is None: if buffer is None:
r = lora_b_stacked[0].size(-1) r = lora_b_stacked[0].size(-1)
...@@ -241,7 +225,6 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -241,7 +225,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y, y,
buffer, # type: ignore buffer, # type: ignore
lora_b_stacked, lora_b_stacked,
None,
output_slices, output_slices,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
......
...@@ -139,28 +139,24 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -139,28 +139,24 @@ class PunicaWrapperTPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: Union[tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM for multiple slices of lora_b.
Semantics: Semantics:
for i in range(len(lora_b_stacked)): for i in range(len(lora_b_stacked)):
slice = output_slices[i] slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
lora_bias_stacked[i]
offset += slice offset += slice
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
...@@ -168,10 +164,6 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -168,10 +164,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
offset_left = 0 offset_left = 0
if lora_bias_stacked is not None:
y = self._apply_bias(
self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked
)
for slice_idx in range(len(lora_b_stacked)): for slice_idx in range(len(lora_b_stacked)):
y = self.expand_slice( y = self.expand_slice(
y, y,
...@@ -214,7 +206,6 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -214,7 +206,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
*, *,
...@@ -231,25 +222,19 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -231,25 +222,19 @@ class PunicaWrapperTPU(PunicaWrapperBase):
@ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :]
* scale * scale
).squeeze(0)+lora_bias_stacked[i] ).squeeze(0)
Args: Args:
y (torch.Tensor): Output tensor. Will not be changed in-place. y (torch.Tensor): Output tensor. Will not be changed in-place.
x (torch.Tensor): Input tensor (T, E) x (torch.Tensor): Input tensor (T, E)
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
""" """
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(
self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked
)
if buffer is None: if buffer is None:
r = lora_b_stacked[0].size(-1) r = lora_b_stacked[0].size(-1)
...@@ -261,7 +246,7 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -261,7 +246,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
) )
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
return self.add_expand( return self.add_expand(
y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
) )
def add_lora_logits( def add_lora_logits(
...@@ -299,43 +284,6 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -299,43 +284,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
return y.view_as(y_org) return y.view_as(y_org)
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: tuple[int, ...],
lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias = torch.where(indices[:, None] == -1, 0, bias)
bias = F.pad(
bias, (offset_left, output.shape[1] - (offset_left + slice), 0, 0)
)
output += bias
offset_left += slice
return output.view_as(org_output)
# This performs the same tensor ops as the base method, except it does them # This performs the same tensor ops as the base method, except it does them
# on the CPU then transfers the results to the TPU # on the CPU then transfers the results to the TPU
def _update_base_metadata( def _update_base_metadata(
......
...@@ -108,36 +108,29 @@ class PunicaWrapperXPU(PunicaWrapperBase): ...@@ -108,36 +108,29 @@ class PunicaWrapperXPU(PunicaWrapperBase):
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM for multiple slices of lora_b.
Semantics: Semantics:
for i in range(len(lora_b_stacked)): for i in range(len(lora_b_stacked)):
slice = output_slices[i] slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
lora_bias_stacked[i]
offset += slice offset += slice
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors x (torch.Tensor): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
y_org = y y_org = y
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
if lora_bias_stacked is not None:
token_lora_indices = self._get_token_lora_indices(y)
self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked)
assert x.ndim == 3 assert x.ndim == 3
assert x.size(0) == len(output_slices) assert x.size(0) == len(output_slices)
...@@ -184,7 +177,6 @@ class PunicaWrapperXPU(PunicaWrapperBase): ...@@ -184,7 +177,6 @@ class PunicaWrapperXPU(PunicaWrapperBase):
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: tuple[int, ...], output_slices: tuple[int, ...],
*, *,
...@@ -201,26 +193,19 @@ class PunicaWrapperXPU(PunicaWrapperBase): ...@@ -201,26 +193,19 @@ class PunicaWrapperXPU(PunicaWrapperBase):
@ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :]
* scale * scale
).squeeze(0)+lora_bias_stacked[i] ).squeeze(0)
Args: Args:
y (torch.Tensor): Output tensor. Will be changed in-place. y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None. buffer (Optional[torch.Tensor]): Defaults to None.
""" """
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
token_lora_indices = self._get_token_lora_indices(y)
y = self._apply_bias(
token_lora_indices, y, output_slices, lora_bias_stacked
)
if buffer is None: if buffer is None:
r = lora_b_stacked[0].size(-1) r = lora_b_stacked[0].size(-1)
...@@ -242,7 +227,6 @@ class PunicaWrapperXPU(PunicaWrapperBase): ...@@ -242,7 +227,6 @@ class PunicaWrapperXPU(PunicaWrapperBase):
y, y,
buffer, # type: ignore buffer, # type: ignore
lora_b_stacked, lora_b_stacked,
None,
output_slices, output_slices,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
......
...@@ -112,7 +112,7 @@ def replace_submodule( ...@@ -112,7 +112,7 @@ def replace_submodule(
def parse_fine_tuned_lora_name( def parse_fine_tuned_lora_name(
name: str, weights_mapper: Optional["WeightsMapper"] = None name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool, bool]: ) -> tuple[str, bool]:
"""Parse the name of lora weights. """Parse the name of lora weights.
args: args:
...@@ -124,7 +124,6 @@ def parse_fine_tuned_lora_name( ...@@ -124,7 +124,6 @@ def parse_fine_tuned_lora_name(
tuple(module_name, is_lora_a): tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1, module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b. is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias.
""" """
# LoRA weight qualified name usually starts with `base_model.model.`, # LoRA weight qualified name usually starts with `base_model.model.`,
...@@ -146,15 +145,11 @@ def parse_fine_tuned_lora_name( ...@@ -146,15 +145,11 @@ def parse_fine_tuned_lora_name(
parts = name.split(".") parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"):
new_name = ".".join(parts[start_index:-2]) new_name = ".".join(parts[start_index:-2])
return new_name, parts[-2] == "lora_A", False return new_name, parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
new_name = ".".join(parts[start_index:-1]) new_name = ".".join(parts[start_index:-1])
return new_name, parts[-1] == "lora_embedding_A", False return new_name, parts[-1] == "lora_embedding_A"
if parts[-1] == "bias":
new_name = ".".join(parts[start_index:-2])
return new_name, False, True
raise ValueError(f"{name} is unsupported LoRA weight") raise ValueError(f"{name} is unsupported LoRA weight")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment