Unverified Commit 82cabf53 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Delete unused LoRA modules (#13151)

parent 314cfade
......@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
assert isinstance(model.get_submodule("gate_up_proj"),
MergedColumnParallelLinearWithLoRA)
# Verify packed lora is correct
model_lora_clone = model_lora.clone(1)
model_lora_clone1 = model_lora1.clone(1)
assert manager.add_adapter(model_lora)
assert manager.add_adapter(model_lora1)
assert model_lora.get_lora("gate_proj") is None
assert model_lora.get_lora("up_proj") is None
assert model_lora1.get_lora("up_proj") is None
packed_lora = model_lora.get_lora("gate_up_proj")
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
torch.testing.assert_close(packed_lora.lora_a[0],
model_lora.get_lora("gate_proj").lora_a)
model_lora_clone.get_lora("gate_proj").lora_a)
torch.testing.assert_close(packed_lora.lora_b[0],
model_lora.get_lora("gate_proj").lora_b)
model_lora_clone.get_lora("gate_proj").lora_b)
torch.testing.assert_close(packed_lora.lora_a[1],
model_lora.get_lora("up_proj").lora_a)
model_lora_clone.get_lora("up_proj").lora_a)
torch.testing.assert_close(packed_lora.lora_b[1],
model_lora.get_lora("up_proj").lora_b)
model_lora_clone.get_lora("up_proj").lora_b)
packed_lora1 = model_lora1.get_lora("gate_up_proj")
assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
......@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
assert packed_lora1.lora_a[0] is None
assert packed_lora1.lora_b[0] is None
torch.testing.assert_close(packed_lora1.lora_a[1],
model_lora1.get_lora("up_proj").lora_a)
model_lora_clone1.get_lora("up_proj").lora_a)
torch.testing.assert_close(packed_lora1.lora_b[1],
model_lora1.get_lora("up_proj").lora_b)
model_lora_clone1.get_lora("up_proj").lora_b)
......@@ -5,7 +5,8 @@ import math
import os
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
Union)
import safetensors.torch
import torch
......@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras: List[Optional[LoRALayerWeights]] = []
replaced_module: Set[str] = set()
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
replacement_loras.append(lora)
if lora:
has_replacement = True
replaced_module.add(r)
if not has_replacement:
continue
for i in range(len(replacement_loras)):
......@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
replacement_loras[i] = None
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras)
# Remove the modules that have been replaced.
for module in replaced_module:
lora_model.loras.pop(module, None)
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
......
......@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
dtype=torch.long,
device=device)
# 5 is the number of indicies tensors.
# 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
self.indices_len: List[Optional[int]] = [None] * 5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment