Unverified Commit 8a06428c authored by Umesh's avatar Umesh Committed by GitHub
Browse files

[LoRA] Adds support for bias in LoRA (#5733)


Signed-off-by: default avatarUmesh Deshpande <udeshpa@us.ibm.com>
Co-authored-by: default avatarUmesh Deshpande <udeshpa@us.ibm.com>
parent b41fb9d3
...@@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id): ...@@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id) return snapshot_download(repo_id=sql_lora_huggingface_id)
@pytest.fixture(scope="session")
def lora_bias_files():
return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def mixtral_lora_files(): def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test # Note: this module has incorrect adapter_config.json to test
......
from typing import List
import pytest
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "ibm-granite/granite-3b-code-base"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
generated_texts: List[str] = []
for output in outputs:
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
return generated_texts
@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_lora_rank=8,
max_loras=1,
enable_lora_bias=lora_bias,
tensor_parallel_size=1,
fully_sharded_loras=fully_sharded)
print("lora adapter created")
output1 = do_sample(llm, lora_bias_files, lora_id=0)
print("lora")
output2 = do_sample(llm, lora_bias_files, lora_id=1)
if lora_bias:
assert output1 != output2
else:
assert output1 == output2
...@@ -12,36 +12,40 @@ from vllm.utils import LRUCache ...@@ -12,36 +12,40 @@ from vllm.utils import LRUCache
def test_parse_fine_tuned_lora_name_valid(): def test_parse_fine_tuned_lora_name_valid():
fixture = { fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True), ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
("base_model.model.lm_head.lora_B.weight", "lm_head", False), ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
( (
"base_model.model.model.embed_tokens.lora_embedding_A", "base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens", "model.embed_tokens",
True, True,
False,
), ),
( (
"base_model.model.model.embed_tokens.lora_embedding_B", "base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens", "model.embed_tokens",
False, False,
False,
), ),
( (
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj", "model.layers.9.mlp.down_proj",
True, True,
False,
), ),
( (
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj", "model.layers.9.mlp.down_proj",
False, False,
False,
), ),
} }
for name, module_name, is_lora_a in fixture: for name, module_name, is_lora_a, is_bias in fixture:
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) assert (module_name, is_lora_a,
is_bias) == parse_fine_tuned_lora_name(name)
def test_parse_fine_tuned_lora_name_invalid(): def test_parse_fine_tuned_lora_name_invalid():
fixture = { fixture = {
"weight",
"base_model.weight", "base_model.weight",
"base_model.model.weight", "base_model.model.weight",
} }
......
...@@ -1687,6 +1687,7 @@ class LoRAConfig: ...@@ -1687,6 +1687,7 @@ class LoRAConfig:
# This is a constant. # This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256 lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None long_lora_scaling_factors: Optional[Tuple[float]] = None
bias_enabled: bool = False
def __post_init__(self): def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast # Setting the maximum rank to 256 should be able to satisfy the vast
......
...@@ -143,6 +143,7 @@ class EngineArgs: ...@@ -143,6 +143,7 @@ class EngineArgs:
limit_mm_per_prompt: Optional[Mapping[str, int]] = None limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1 max_loras: int = 1
max_lora_rank: int = 16 max_lora_rank: int = 16
enable_prompt_adapter: bool = False enable_prompt_adapter: bool = False
...@@ -584,6 +585,9 @@ class EngineArgs: ...@@ -584,6 +585,9 @@ class EngineArgs:
parser.add_argument('--enable-lora', parser.add_argument('--enable-lora',
action='store_true', action='store_true',
help='If True, enable handling of LoRA adapters.') help='If True, enable handling of LoRA adapters.')
parser.add_argument('--enable-lora-bias',
action='store_true',
help='If True, enable bias for LoRA adapters.')
parser.add_argument('--max-loras', parser.add_argument('--max-loras',
type=int, type=int,
default=EngineArgs.max_loras, default=EngineArgs.max_loras,
...@@ -1148,6 +1152,7 @@ class EngineArgs: ...@@ -1148,6 +1152,7 @@ class EngineArgs:
and parallel_config.use_ray), and parallel_config.use_ray),
policy=self.scheduling_policy) policy=self.scheduling_policy)
lora_config = LoRAConfig( lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras, max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras, fully_sharded_loras=self.fully_sharded_loras,
......
...@@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked, self.lora_b_stacked,
add_input=True) add_input=True)
# now have column partitioned output # now have column partitioned output
if self.bias_stacked is not None:
self.bias_stacked = self.bias_stacked.view(
-1, self.bias_stacked.shape[-1])
self.bias_stacked = self.bias_stacked[
self.punica_wrapper.token_lora_indices]
output += self.bias_stacked
output = output.view(*out_orig_shape) output = output.view(*out_orig_shape)
return output return output
...@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): ...@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
left_offset = 0 left_offset = 0
for idx in range(n): for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2] shard_size = layer.lora_b_stacked[idx].shape[2]
if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.punica_wrapper.token_lora_indices]
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
output[:, left_offset:left_offset + shard_size] += bias
layer.punica_wrapper.add_expand_slice( layer.punica_wrapper.add_expand_slice(
output, output,
buffers[idx], buffers[idx],
...@@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b = lora_b[:, start_idx:end_idx] lora_b = lora_b[:, start_idx:end_idx]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
shard_size = self.bias_stacked.shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias
def apply(self, x: torch.Tensor) -> torch.Tensor: def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x) output = self.base_layer.quant_method.apply(self.base_layer, x)
...@@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# reduced before being used # reduced before being used
shard_size = self.lora_b_stacked.shape[2] shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias
self.punica_wrapper.add_expand_slice(output, buffer, self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx, self.lora_b_stacked, start_idx,
shard_size) shard_size)
......
This diff is collapsed.
...@@ -17,6 +17,7 @@ class LoRALayerWeights: ...@@ -17,6 +17,7 @@ class LoRALayerWeights:
lora_alpha: int, lora_alpha: int,
lora_a: torch.Tensor, lora_a: torch.Tensor,
lora_b: torch.Tensor, lora_b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
embeddings_tensor: Optional[torch.Tensor] = None, embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None, scaling: Optional[float] = None,
) -> None: ) -> None:
...@@ -25,6 +26,7 @@ class LoRALayerWeights: ...@@ -25,6 +26,7 @@ class LoRALayerWeights:
self.lora_alpha = lora_alpha self.lora_alpha = lora_alpha
self.lora_a = lora_a self.lora_a = lora_a
self.lora_b = lora_b self.lora_b = lora_b
self.bias = bias
self.embeddings_tensor = embeddings_tensor self.embeddings_tensor = embeddings_tensor
if scaling is None: if scaling is None:
...@@ -66,7 +68,8 @@ class LoRALayerWeights: ...@@ -66,7 +68,8 @@ class LoRALayerWeights:
rank: int, rank: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.types.Device, device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank], lora_a = torch.zeros([input_dim, rank],
dtype=dtype, dtype=dtype,
...@@ -76,6 +79,14 @@ class LoRALayerWeights: ...@@ -76,6 +79,14 @@ class LoRALayerWeights:
dtype=dtype, dtype=dtype,
device=device, device=device,
pin_memory=pin_memory) pin_memory=pin_memory)
if bias_enabled:
bias = torch.zeros([output_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
else:
bias = None
embeddings_tensor = torch.rand( embeddings_tensor = torch.rand(
10, 10,
embeddings_tensor_dim, embeddings_tensor_dim,
...@@ -88,6 +99,7 @@ class LoRALayerWeights: ...@@ -88,6 +99,7 @@ class LoRALayerWeights:
lora_alpha=1, lora_alpha=1,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
bias=bias,
embeddings_tensor=embeddings_tensor, embeddings_tensor=embeddings_tensor,
) )
...@@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alphas: List[Optional[int]], lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]], lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]], lora_b: List[Optional[torch.Tensor]],
bias: Optional[List[Optional[torch.Tensor]]] = None,
scaling: Optional[List[float]] = None, scaling: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0, lora_alpha=0,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
bias=bias,
scaling=scaling, # type: ignore scaling=scaling, # type: ignore
embeddings_tensor=None, embeddings_tensor=None,
) )
...@@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras],
[lora.bias if lora is not None else None for lora in loras],
scaling=[ scaling=[
1 if lora is not None else None # type: ignore 1 if lora is not None else None # type: ignore
for lora in loras for lora in loras
......
...@@ -4,7 +4,7 @@ import math ...@@ -4,7 +4,7 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type from typing import Any, Callable, Dict, List, Optional, Sequence, Type
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -119,7 +119,8 @@ class LoRAModel(AdapterModel): ...@@ -119,7 +119,8 @@ class LoRAModel(AdapterModel):
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {} loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items(): for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name)
if module_name not in loras: if module_name not in loras:
lora_embeddings_tensor = None lora_embeddings_tensor = None
if embeddings: if embeddings:
...@@ -136,8 +137,16 @@ class LoRAModel(AdapterModel): ...@@ -136,8 +137,16 @@ class LoRAModel(AdapterModel):
lora_embeddings_tensor.pin_memory()) lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank, loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None, lora_alpha, None, None,
None,
lora_embeddings_tensor) lora_embeddings_tensor)
if is_lora_a: if is_bias:
loras[module_name].bias = tensor.to(device=device,
dtype=dtype).t()
bias = tensor.to(device=device, dtype=dtype).t()
if pin_memory:
bias = bias.pin_memory()
loras[module_name].bias = bias
elif is_lora_a:
loras[module_name].lora_a = tensor.to(device=device, loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t() dtype=dtype).t()
if pin_memory: if pin_memory:
...@@ -215,7 +224,7 @@ class LoRAModel(AdapterModel): ...@@ -215,7 +224,7 @@ class LoRAModel(AdapterModel):
with safetensors.safe_open(lora_tensor_path, with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa for lora_module in f.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module) module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
part_name = module_name.split(".")[-1] part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules: if part_name not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
...@@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager): ...@@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager):
module_lora = lora_model.get_lora(module_name) module_lora = lora_model.get_lora(module_name)
if module_lora: if module_lora:
module_lora.optimize() module_lora.optimize()
# Bias is not explicitly enabled with the flag enable_lora_bias.
bias = module_lora.bias
if ((torch.is_tensor(bias) or
(isinstance(bias, Sequence) and any(b is not None
for b in bias)))
and not self.lora_config.bias_enabled):
module_lora.bias = None
raise ValueError(
f"Adapter bias cannot be used for {module_name}"
" without --enable-lora-bias.")
module.set_lora(index, module_lora.lora_a, module_lora.lora_b, module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
module_lora.embeddings_tensor) module_lora.embeddings_tensor,
module_lora.bias)
else: else:
module.reset_lora(index) module.reset_lora(index)
return True return True
...@@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup.""" """Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor) model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules():
bias_enabled = self.lora_config.bias_enabled
if (not self._match_target_modules(module_name) if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA) or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora) or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
...@@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager): ...@@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager):
rank, rank,
module.lora_a_stacked.dtype, module.lora_a_stacked.dtype,
"cpu", "cpu",
embeddings_tensor_dim=embeddings_tensor_dim) embeddings_tensor_dim=embeddings_tensor_dim,
bias_enabled=bias_enabled)
else: else:
lora = LoRALayerWeights.create_dummy_lora_weights( lora = LoRALayerWeights.create_dummy_lora_weights(
module_name, module_name,
...@@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager):
rank, rank,
module.lora_a_stacked.dtype, module.lora_a_stacked.dtype,
"cpu", "cpu",
bias_enabled=bias_enabled,
) )
lora.optimize() lora.optimize()
else: else:
...@@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager):
rank, rank,
module.lora_a_stacked[i].dtype, module.lora_a_stacked[i].dtype,
"cpu", "cpu",
bias_enabled=bias_enabled,
) )
lora.optimize() lora.optimize()
subloras.append(lora) subloras.append(lora)
......
...@@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str, ...@@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str,
return new_module return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
"""Parse the name of lora weights. """Parse the name of lora weights.
args: args:
...@@ -101,15 +101,18 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: ...@@ -101,15 +101,18 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
Tuple(module_name, is_lora_a): Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1, module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b. is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias.
""" """
parts = name.split(".") parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A"
or parts[-2] == "lora_B"):
return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model": if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
if parts[-1] == "weight": return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
if parts[-2] == "lora_A" or parts[-2] == "lora_B":
return ".".join(parts[2:-2]), parts[-2] == "lora_A" if parts[-1] == "bias":
elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": return ".".join(parts[2:-2]), False, True
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported LoRA weight") raise ValueError(f"{name} is unsupported LoRA weight")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment