Unverified commit 8a06428c, authored by Umesh, committed by GitHub
Browse files

[LoRA] Adds support for bias in LoRA (#5733)


Signed-off-by: Umesh Deshpande <udeshpa@us.ibm.com>
Co-authored-by: Umesh Deshpande <udeshpa@us.ibm.com>
parent b41fb9d3
...@@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id): ...@@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id) return snapshot_download(repo_id=sql_lora_huggingface_id)
@pytest.fixture(scope="session")
def lora_bias_files():
    """Session-scoped fixture for the bias-carrying LoRA adapter.

    Returns whatever ``snapshot_download`` returns for the
    ``followumesh/granite-3b-lora8-bias`` repository.
    """
    adapter_repo = "followumesh/granite-3b-lora8-bias"
    return snapshot_download(repo_id=adapter_repo)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def mixtral_lora_files(): def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test # Note: this module has incorrect adapter_config.json to test
......
from typing import List
import pytest
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "ibm-granite/granite-3b-code-base"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    """Generate completions for two fixed SQL prompts.

    Args:
        llm: Engine whose ``generate`` method is called.
        lora_path: Path handed to the ``LoRARequest``.
        lora_id: Adapter id. A falsy value (e.g. ``0``) means no
            ``LoRARequest`` is sent at all.

    Returns:
        The ``text`` of the first entry in ``outputs`` for every result
        returned by ``llm.generate``, in the order they were returned.
    """
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"  # noqa: E501
    ]
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=256,
                                          stop=["[/assistant]"])

    # Only attach an adapter when a non-zero id was requested.
    if lora_id:
        adapter_request = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        adapter_request = None

    results = llm.generate(prompts,
                           sampling_params,
                           lora_request=adapter_request)
    return [result.outputs[0].text for result in results]
@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
    """Compare generations with ``lora_id=0`` and ``lora_id=1``.

    With ``lora_bias`` set, the two sets of generated texts must differ;
    otherwise they must be equal.
    """
    engine_options = dict(
        enable_lora=True,
        max_num_seqs=16,
        max_lora_rank=8,
        max_loras=1,
        enable_lora_bias=lora_bias,
        tensor_parallel_size=1,
        fully_sharded_loras=fully_sharded,
    )
    llm = vllm.LLM(MODEL_PATH, **engine_options)

    print("lora adapter created")
    texts_id0 = do_sample(llm, lora_bias_files, lora_id=0)

    print("lora")
    texts_id1 = do_sample(llm, lora_bias_files, lora_id=1)

    outputs_differ = texts_id0 != texts_id1
    if lora_bias:
        assert outputs_differ
    else:
        assert not outputs_differ
...@@ -12,36 +12,40 @@ from vllm.utils import LRUCache ...@@ -12,36 +12,40 @@ from vllm.utils import LRUCache
def test_parse_fine_tuned_lora_name_valid(): def test_parse_fine_tuned_lora_name_valid():
fixture = { fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True), ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
("base_model.model.lm_head.lora_B.weight", "lm_head", False), ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
( (
"base_model.model.model.embed_tokens.lora_embedding_A", "base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens", "model.embed_tokens",
True, True,
False,
), ),
( (
"base_model.model.model.embed_tokens.lora_embedding_B", "base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens", "model.embed_tokens",
False, False,
False,
), ),
( (
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj", "model.layers.9.mlp.down_proj",
True, True,
False,
), ),
( (
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj", "model.layers.9.mlp.down_proj",
False, False,
False,
), ),
} }
for name, module_name, is_lora_a in fixture: for name, module_name, is_lora_a, is_bias in fixture:
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) assert (module_name, is_lora_a,
is_bias) == parse_fine_tuned_lora_name(name)
def test_parse_fine_tuned_lora_name_invalid(): def test_parse_fine_tuned_lora_name_invalid():
fixture = { fixture = {
"weight",
"base_model.weight", "base_model.weight",
"base_model.model.weight", "base_model.model.weight",
} }
......
...@@ -1687,6 +1687,7 @@ class LoRAConfig: ...@@ -1687,6 +1687,7 @@ class LoRAConfig:
# This is a constant. # This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256 lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None long_lora_scaling_factors: Optional[Tuple[float]] = None
bias_enabled: bool = False
def __post_init__(self): def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast # Setting the maximum rank to 256 should be able to satisfy the vast
......
...@@ -143,6 +143,7 @@ class EngineArgs: ...@@ -143,6 +143,7 @@ class EngineArgs:
limit_mm_per_prompt: Optional[Mapping[str, int]] = None limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1 max_loras: int = 1
max_lora_rank: int = 16 max_lora_rank: int = 16
enable_prompt_adapter: bool = False enable_prompt_adapter: bool = False
...@@ -584,6 +585,9 @@ class EngineArgs: ...@@ -584,6 +585,9 @@ class EngineArgs:
parser.add_argument('--enable-lora', parser.add_argument('--enable-lora',
action='store_true', action='store_true',
help='If True, enable handling of LoRA adapters.') help='If True, enable handling of LoRA adapters.')
parser.add_argument('--enable-lora-bias',
action='store_true',
help='If True, enable bias for LoRA adapters.')
parser.add_argument('--max-loras', parser.add_argument('--max-loras',
type=int, type=int,
default=EngineArgs.max_loras, default=EngineArgs.max_loras,
...@@ -1148,6 +1152,7 @@ class EngineArgs: ...@@ -1148,6 +1152,7 @@ class EngineArgs:
and parallel_config.use_ray), and parallel_config.use_ray),
policy=self.scheduling_policy) policy=self.scheduling_policy)
lora_config = LoRAConfig( lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras, max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras, fully_sharded_loras=self.fully_sharded_loras,
......
...@@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked, self.lora_b_stacked,
add_input=True) add_input=True)
# now have column partitioned output # now have column partitioned output
if self.bias_stacked is not None:
self.bias_stacked = self.bias_stacked.view(
-1, self.bias_stacked.shape[-1])
self.bias_stacked = self.bias_stacked[
self.punica_wrapper.token_lora_indices]
output += self.bias_stacked
output = output.view(*out_orig_shape) output = output.view(*out_orig_shape)
return output return output
...@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): ...@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
left_offset = 0 left_offset = 0
for idx in range(n): for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2] shard_size = layer.lora_b_stacked[idx].shape[2]
if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.punica_wrapper.token_lora_indices]
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
output[:, left_offset:left_offset + shard_size] += bias
layer.punica_wrapper.add_expand_slice( layer.punica_wrapper.add_expand_slice(
output, output,
buffers[idx], buffers[idx],
...@@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b = lora_b[:, start_idx:end_idx] lora_b = lora_b[:, start_idx:end_idx]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
    """Return this rank's contiguous chunk of ``bias`` along dim 0.

    ``None`` is passed through unchanged. The chunk width is the last
    dimension of ``self.bias_stacked`` and the chunk position is selected
    by ``self.tp_rank``.
    """
    if bias is None:
        return None
    # NOTE(review): the width comes from the stacked buffer, not from
    # ``bias`` itself - confirm the buffer is sized per rank in this class.
    width = self.bias_stacked.shape[2]
    first = self.tp_rank * width
    return bias[first:first + width]
def apply(self, x: torch.Tensor) -> torch.Tensor: def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x) output = self.base_layer.quant_method.apply(self.base_layer, x)
...@@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# reduced before being used # reduced before being used
shard_size = self.lora_b_stacked.shape[2] shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias
self.punica_wrapper.add_expand_slice(output, buffer, self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx, self.lora_b_stacked, start_idx,
shard_size) shard_size)
......
...@@ -67,6 +67,63 @@ def _not_fully_sharded_can_replace(can_replace): ...@@ -67,6 +67,63 @@ def _not_fully_sharded_can_replace(can_replace):
return dec return dec
def apply_bias(
    indices: torch.Tensor,
    output: torch.Tensor,
    bias_stacked: torch.Tensor,
):
    """Add the per-row LoRA bias to ``output`` in place.

    All three tensors are flattened with ``view`` first: ``output`` and
    ``bias_stacked`` to 2-D over their last dimension, ``indices`` to 1-D.

    Args:
        indices: One bias-row index per flattened output row. Rows whose
            index is ``-1`` receive a zero bias.
        output: Tensor that is updated in place.
        bias_stacked: Stacked biases; after flattening, row ``i`` is the
            bias selected by index ``i``.

    Returns:
        The updated data viewed with the shape of the ``output`` argument.
    """
    flat_output = output.view(-1, output.shape[-1])
    flat_indices = indices.view(-1)

    # Indexing with a tensor yields a fresh tensor, so zeroing rows below
    # does not touch the caller's ``bias_stacked``.
    per_row_bias = bias_stacked.view(-1, bias_stacked.shape[-1])[flat_indices]
    per_row_bias[flat_indices == -1] = 0

    flat_output += per_row_bias
    return flat_output.view_as(output)
def apply_bias_packed_nslice(
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
    bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
    """Add per-row LoRA biases to consecutive column slices of ``output``.

    ``output`` is flattened to 2-D over its last dimension and updated in
    place. Slice ``i`` covers ``output_slices[i]`` columns, starting where
    slice ``i - 1`` ended, and receives the bias ``bias_stacked[i]``.

    Args:
        indices: One bias-row index per flattened output row. Rows whose
            index is ``-1`` receive a zero bias.
        output: Tensor that is updated in place; its last dimension is
            expected to hold the slices back to back.
        output_slices: Width of every slice, one entry per slice.
        bias_stacked: One stacked bias tensor per slice (each is flattened
            to 2-D over its last dimension). A ``None`` entry leaves that
            slice untouched.

    Returns:
        The updated data viewed with the shape of the ``output`` argument.
    """
    org_output = output
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    # Loop invariant: rows that have no adapter assigned.
    no_adapter = indices == -1

    offset_left = 0
    for slice_idx, slice_size in enumerate(output_slices):
        bias = bias_stacked[slice_idx]
        if bias is not None:
            # Tensor indexing yields a fresh tensor, so the zeroing below
            # does not modify the stacked bias buffer.
            bias = bias.view(-1, bias.shape[-1])[indices]
            bias[no_adapter] = 0
            output[:, offset_left:offset_left + slice_size] += bias
        # Advance even when a slice has no bias so later slices stay aligned.
        offset_left += slice_size

    return output.view_as(org_output)
@dataclass @dataclass
class LoRAMapping(AdapterMapping): class LoRAMapping(AdapterMapping):
is_prefill: bool = False is_prefill: bool = False
...@@ -105,6 +162,7 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -105,6 +162,7 @@ class BaseLayerWithLoRA(nn.Module):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
"""Overwrites lora tensors at index.""" """Overwrites lora tensors at index."""
... ...
...@@ -203,6 +261,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -203,6 +261,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
...@@ -299,10 +358,22 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): ...@@ -299,10 +358,22 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
if lora_config.bias_enabled:
self.bias_stacked = torch.zeros(
max_loras,
1,
self.output_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
else:
self.bias_stacked = None
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0 self.lora_b_stacked[index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[index] = 0
def set_lora( def set_lora(
self, self,
...@@ -310,6 +381,7 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): ...@@ -310,6 +381,7 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
...@@ -319,10 +391,21 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): ...@@ -319,10 +391,21 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0) self.lora_b_stacked, 1.0)
return output return output
...@@ -401,11 +484,25 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -401,11 +484,25 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
if lora_config.bias_enabled:
self.bias_stacked = torch.zeros(
max_loras,
1,
self.output_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
else:
self.bias_stacked = None
self.output_dim = self.lora_b_stacked.shape[2] self.output_dim = self.lora_b_stacked.shape[2]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0 self.lora_b_stacked[index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
return lora_a return lora_a
...@@ -418,18 +515,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -418,18 +515,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
lora_b = lora_b[:, start_idx:end_idx] lora_b = lora_b[:, start_idx:end_idx]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's chunk of ``bias`` along dim 0.

    ``None`` is passed through unchanged. The chunk is ``self.output_dim``
    entries wide and its position is given by
    ``get_tensor_model_parallel_rank()``.
    """
    if bias is None:
        return None
    rank = get_tensor_model_parallel_rank()
    width = self.output_dim
    return bias[rank * width:(rank + 1) * width]
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
bias = self.slice_bias(bias)
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
...@@ -437,10 +546,21 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -437,10 +546,21 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0) self.lora_b_stacked, 1.0)
return output return output
...@@ -534,6 +654,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -534,6 +654,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) for _ in range(n_slices)) ) for _ in range(n_slices))
if lora_config.bias_enabled:
self.bias_stacked = tuple(
torch.zeros(
max_loras,
1,
self.output_size // 2,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(n_slices))
else:
self.bias_stacked = None
self.output_dim = self.lora_b_stacked[0].shape[2] self.output_dim = self.lora_b_stacked[0].shape[2]
...@@ -542,6 +673,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -542,6 +673,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[1][index] = 0 self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0 self.lora_b_stacked[1][index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[0][index] = 0
self.bias_stacked[1][index] = 0
def slice_lora_a( def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]] self, lora_a: List[Union[torch.Tensor, None]]
...@@ -562,18 +696,32 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -562,18 +696,32 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
] ]
return lora_b return lora_b
def slice_bias(
    self, bias: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
    """Slice each of the two packed biases down to this rank's shard.

    Every entry that is present is cut to
    ``[tp_rank * output_dim, (tp_rank + 1) * output_dim)`` along dim 0.
    A ``None`` entry stays ``None``.

    Previously, if exactly one entry was ``None`` the other was returned
    unsliced. ``set_lora`` accepts a bias for just one of the two slices,
    so that entry must still be sharded.

    Args:
        bias: Two-element sequence holding the bias for each packed slice,
            or ``None`` where a slice has no bias.

    Returns:
        A new two-element list with the sliced biases.
    """
    shard_size = self.output_dim
    start_idx = self.tp_rank * shard_size
    end_idx = start_idx + shard_size
    return [
        None if part is None else part[start_idx:end_idx]
        for part in (bias[0], bias[1])
    ]
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
if lora_a[0] is not None: if lora_a[0] is not None:
self.lora_a_stacked[0][ self.lora_a_stacked[0][
...@@ -582,6 +730,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -582,6 +730,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][ self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True) lora_b[0].T, non_blocking=True)
if bias is not None and bias[0] is not None:
self.bias_stacked[0][index,
0, :bias[0].shape[0]].copy_(bias[0].T,
non_blocking=True)
if lora_a[1] is not None: if lora_a[1] is not None:
self.lora_a_stacked[1][ self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
...@@ -589,10 +741,22 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -589,10 +741,22 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[1][ self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True) lora_b[1].T, non_blocking=True)
if bias is not None and bias[1] is not None:
self.bias_stacked[1][index,
0, :bias[1].shape[0]].copy_(bias[1].T,
non_blocking=True)
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
(self.output_dim, self.output_dim),
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice( self.punica_wrapper.add_lora_packed_nslice(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
(self.output_dim, self.output_dim)) (self.output_dim, self.output_dim))
...@@ -654,17 +818,35 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -654,17 +818,35 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
    """Extract this rank's q, k and v bias shards and re-pack them.

    ``bias`` is read along dim 0 as three consecutive regions: q
    (``q_proj_total_size`` entries), then k and v
    (``kv_proj_total_size`` entries each). The q shard is selected with
    ``q_shard_id`` and the k/v shards with ``kv_shard_id``.

    Args:
        bias: Full packed bias, sliced along dim 0.

    Returns:
        The three shards concatenated along dim 0, in q, k, v order.
    """
    q_start = self.q_proj_shard_size * self.q_shard_id
    bias_q = bias[q_start:q_start + self.q_proj_shard_size]

    kv_start = self.kv_proj_shard_size * self.kv_shard_id
    kv_end = kv_start + self.kv_proj_shard_size

    k_offset = self.q_proj_total_size
    bias_k = bias[k_offset + kv_start:k_offset + kv_end]

    v_offset = k_offset + self.kv_proj_total_size
    bias_v = bias[v_offset + kv_start:v_offset + kv_end]

    # The shards were cut along dim 0, so they must be joined along dim 0.
    # Concatenating along dim 1 raises on a 1-D bias.
    return torch.cat([bias_q, bias_k, bias_v], dim=0)
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
...@@ -672,6 +854,10 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -672,6 +854,10 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
...@@ -768,6 +954,32 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -768,6 +954,32 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
device=self.device, device=self.device,
), ),
) )
if lora_config.bias_enabled:
self.bias_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
else:
self.bias_stacked = None
self.output_slices = ( self.output_slices = (
self.q_proj_shard_size, self.q_proj_shard_size,
...@@ -787,6 +999,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -787,6 +999,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[1][index] = 0 self.lora_b_stacked[1][index] = 0
self.lora_a_stacked[2][index] = 0 self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[0][index] = 0
self.bias_stacked[1][index] = 0
self.bias_stacked[2][index] = 0
def slice_lora_a( def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]] self, lora_a: List[Union[torch.Tensor, None]]
...@@ -812,18 +1028,40 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -812,18 +1028,40 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_b = [lora_b_q, lora_b_k, lora_b_v] lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b return lora_b
def slice_bias(
    self, bias: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
    """Slice the separate q, k and v biases down to this rank's shards.

    Each entry that is present is cut along dim 0: q uses
    ``q_proj_shard_size`` / ``q_shard_id``, k and v use
    ``kv_proj_shard_size`` / ``kv_shard_id``. ``None`` entries stay
    ``None``.

    Args:
        bias: Three-element sequence ``(q, k, v)``.

    Returns:
        A new list ``[q, k, v]`` with the sliced biases.
    """
    q_part, k_part, v_part = bias

    if q_part is not None:
        q_from = self.q_proj_shard_size * self.q_shard_id
        q_part = q_part[q_from:q_from + self.q_proj_shard_size]

    if k_part is not None:
        k_from = self.kv_proj_shard_size * self.kv_shard_id
        k_part = k_part[k_from:k_from + self.kv_proj_shard_size]

    if v_part is not None:
        v_from = self.kv_proj_shard_size * self.kv_shard_id
        v_part = v_part[v_from:v_from + self.kv_proj_shard_size]

    return [q_part, k_part, v_part]
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
if lora_b[0] is not None: if lora_b[0] is not None:
lora_b_q = lora_b[0] lora_b_q = lora_b[0]
...@@ -854,9 +1092,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -854,9 +1092,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True) lora_a[2].T, non_blocking=True)
if bias is not None:
if bias[0] is not None:
self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_(
bias[0].T, non_blocking=True)
if bias[1] is not None:
self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_(
bias[1].T, non_blocking=True)
if bias[2] is not None:
self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_(
bias[2].T, non_blocking=True)
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
self.output_slices,
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(output, x, self.punica_wrapper.add_lora_packed_nslice(output, x,
self.lora_a_stacked, self.lora_a_stacked,
self.lora_b_stacked, 1.0, self.lora_b_stacked, 1.0,
...@@ -919,9 +1176,27 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -919,9 +1176,27 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
device=self.device, device=self.device,
) )
if lora_config.bias_enabled:
self.bias_stacked = torch.zeros(
(
max_loras,
1,
self.output_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
else:
self.bias_stacked = None
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0 self.lora_b_stacked[index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
...@@ -934,18 +1209,24 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -934,18 +1209,24 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
    """Return ``bias`` unchanged; no slicing is applied here."""
    return bias
def set_lora( def set_lora(
self, self,
index: int, index: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
if self.base_layer.tp_size > 1: if self.base_layer.tp_size > 1:
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
...@@ -953,9 +1234,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -953,9 +1234,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
def apply(self, x: torch.Tensor) -> torch.Tensor:
    """Run the base layer on ``x`` and add the LoRA contributions.

    Args:
        x: input tensor handed to the base layer's quantization method.

    Returns:
        The base layer output after the optional LoRA bias and the LoRA
        A/B update have been applied.
    """
    output = self.base_layer.quant_method.apply(self.base_layer, x)
    if self.bias_stacked is not None:
        # The token -> LoRA-slot indices are cached on the instance before
        # being passed to apply_bias.
        self.indices = self.punica_wrapper.token_lora_indices
        output = apply_bias(
            self.indices,
            output,
            self.bias_stacked,
        )
    # NOTE(review): add_lora's return value is ignored, so it presumably
    # accumulates into ``output`` in place -- confirm against the wrapper.
    self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                 self.lora_b_stacked, 1.0)
    return output
...@@ -1132,6 +1424,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1132,6 +1424,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
self.reset_lora(index) self.reset_lora(index)
self.lora_a_stacked[index, self.lora_a_stacked[index,
...@@ -1199,7 +1492,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1199,7 +1492,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
neginf=float("-inf"))) neginf=float("-inf")))
logits[:, logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1], ] = lora_logits lora_logits.shape[1]] = lora_logits
# LogitsProcessorWithLoRA always using bgmv # LogitsProcessorWithLoRA always using bgmv
self.punica_wrapper.add_lora_logits(logits, hidden_states, self.punica_wrapper.add_lora_logits(logits, hidden_states,
...@@ -1276,6 +1569,7 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): ...@@ -1276,6 +1569,7 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
): ):
... ...
......
...@@ -17,6 +17,7 @@ class LoRALayerWeights: ...@@ -17,6 +17,7 @@ class LoRALayerWeights:
lora_alpha: int, lora_alpha: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
embeddings_tensor: Optional[torch.Tensor] = None, embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None, scaling: Optional[float] = None,
) -> None: ) -> None:
...@@ -25,6 +26,7 @@ class LoRALayerWeights: ...@@ -25,6 +26,7 @@ class LoRALayerWeights:
self.lora_alpha = lora_alpha self.lora_alpha = lora_alpha
self.lora_a = lora_a self.lora_a = lora_a
self.lora_b = lora_b self.lora_b = lora_b
self.bias = bias
self.embeddings_tensor = embeddings_tensor self.embeddings_tensor = embeddings_tensor
if scaling is None: if scaling is None:
...@@ -66,7 +68,8 @@ class LoRALayerWeights: ...@@ -66,7 +68,8 @@ class LoRALayerWeights:
rank: int, rank: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.types.Device, device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank], lora_a = torch.zeros([input_dim, rank],
dtype=dtype, dtype=dtype,
...@@ -76,6 +79,14 @@ class LoRALayerWeights: ...@@ -76,6 +79,14 @@ class LoRALayerWeights:
dtype=dtype, dtype=dtype,
device=device, device=device,
pin_memory=pin_memory) pin_memory=pin_memory)
if bias_enabled:
bias = torch.zeros([output_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
else:
bias = None
embeddings_tensor = torch.rand( embeddings_tensor = torch.rand(
10, 10,
embeddings_tensor_dim, embeddings_tensor_dim,
...@@ -88,6 +99,7 @@ class LoRALayerWeights: ...@@ -88,6 +99,7 @@ class LoRALayerWeights:
lora_alpha=1, lora_alpha=1,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
bias=bias,
embeddings_tensor=embeddings_tensor, embeddings_tensor=embeddings_tensor,
) )
...@@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alphas: List[Optional[int]], lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]], lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]], lora_b: List[Optional[torch.Tensor]],
bias: Optional[List[Optional[torch.Tensor]]] = None,
scaling: Optional[List[float]] = None, scaling: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0, lora_alpha=0,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
bias=bias,
scaling=scaling, # type: ignore scaling=scaling, # type: ignore
embeddings_tensor=None, embeddings_tensor=None,
) )
...@@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras],
[lora.bias if lora is not None else None for lora in loras],
scaling=[ scaling=[
1 if lora is not None else None # type: ignore 1 if lora is not None else None # type: ignore
for lora in loras for lora in loras
......
...@@ -4,7 +4,7 @@ import math ...@@ -4,7 +4,7 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type from typing import Any, Callable, Dict, List, Optional, Sequence, Type
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -119,7 +119,8 @@ class LoRAModel(AdapterModel): ...@@ -119,7 +119,8 @@ class LoRAModel(AdapterModel):
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {} loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items(): for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name)
if module_name not in loras: if module_name not in loras:
lora_embeddings_tensor = None lora_embeddings_tensor = None
if embeddings: if embeddings:
...@@ -136,8 +137,16 @@ class LoRAModel(AdapterModel): ...@@ -136,8 +137,16 @@ class LoRAModel(AdapterModel):
lora_embeddings_tensor.pin_memory()) lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank, loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None, lora_alpha, None, None,
None,
lora_embeddings_tensor) lora_embeddings_tensor)
if is_lora_a: if is_bias:
loras[module_name].bias = tensor.to(device=device,
dtype=dtype).t()
bias = tensor.to(device=device, dtype=dtype).t()
if pin_memory:
bias = bias.pin_memory()
loras[module_name].bias = bias
elif is_lora_a:
loras[module_name].lora_a = tensor.to(device=device, loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t() dtype=dtype).t()
if pin_memory: if pin_memory:
...@@ -215,7 +224,7 @@ class LoRAModel(AdapterModel): ...@@ -215,7 +224,7 @@ class LoRAModel(AdapterModel):
with safetensors.safe_open(lora_tensor_path, with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa for lora_module in f.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module) module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
part_name = module_name.split(".")[-1] part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules: if part_name not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
...@@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager): ...@@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager):
module_lora = lora_model.get_lora(module_name) module_lora = lora_model.get_lora(module_name)
if module_lora: if module_lora:
module_lora.optimize() module_lora.optimize()
# Bias is not explicitly enabled with the flag enable_lora_bias.
bias = module_lora.bias
if ((torch.is_tensor(bias) or
(isinstance(bias, Sequence) and any(b is not None
for b in bias)))
and not self.lora_config.bias_enabled):
module_lora.bias = None
raise ValueError(
f"Adapter bias cannot be used for {module_name}"
" without --enable-lora-bias.")
module.set_lora(index, module_lora.lora_a, module_lora.lora_b, module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
module_lora.embeddings_tensor) module_lora.embeddings_tensor,
module_lora.bias)
else: else:
module.reset_lora(index) module.reset_lora(index)
return True return True
...@@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup.""" """Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor) model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules():
bias_enabled = self.lora_config.bias_enabled
if (not self._match_target_modules(module_name) if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA) or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora) or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
...@@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager): ...@@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager):
rank, rank,
module.lora_a_stacked.dtype, module.lora_a_stacked.dtype,
"cpu", "cpu",
embeddings_tensor_dim=embeddings_tensor_dim) embeddings_tensor_dim=embeddings_tensor_dim,
bias_enabled=bias_enabled)
else: else:
lora = LoRALayerWeights.create_dummy_lora_weights( lora = LoRALayerWeights.create_dummy_lora_weights(
module_name, module_name,
...@@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager):
rank, rank,
module.lora_a_stacked.dtype, module.lora_a_stacked.dtype,
"cpu", "cpu",
bias_enabled=bias_enabled,
) )
lora.optimize() lora.optimize()
else: else:
...@@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager):
rank, rank,
module.lora_a_stacked[i].dtype, module.lora_a_stacked[i].dtype,
"cpu", "cpu",
bias_enabled=bias_enabled,
) )
lora.optimize() lora.optimize()
subloras.append(lora) subloras.append(lora)
......
...@@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str, ...@@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str,
return new_module return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
    """Parse the name of lora weights.

    args:
        name: the dotted name of the fine-tuned LoRA tensor, e.g.
            base_model.model.dense1.lora_A.weight
    return:
        Tuple(module_name, is_lora_a, is_bias):
            module_name: the name of the module, e.g. model.dense1,
            is_lora_a whether the tensor is lora_a or lora_b.
            is_bias whether the tensor is lora bias.
    raises:
        ValueError: if the name matches none of the supported layouts.
    """
    parts = name.split(".")
    leaf = parts[-1]
    # A one-component name (e.g. "weight") has no parent component; guard
    # the lookup so such names reach the ValueError below instead of
    # failing with IndexError.
    parent = parts[-2] if len(parts) >= 2 else None

    # The first two components are dropped from the module name; the
    # slicing assumes they are the "base_model.model" prefix.
    if leaf == "weight" and parent in ("lora_A", "lora_B"):
        return ".".join(parts[2:-2]), parent == "lora_A", False

    if leaf in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[2:-1]), leaf == "lora_embedding_A", False

    if leaf == "bias":
        return ".".join(parts[2:-2]), False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment