Unverified Commit 82cabf53 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Delete unused LoRA modules (#13151)

parent 314cfade
...@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): ...@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
assert isinstance(model.get_submodule("gate_up_proj"), assert isinstance(model.get_submodule("gate_up_proj"),
MergedColumnParallelLinearWithLoRA) MergedColumnParallelLinearWithLoRA)
# Verify packed lora is correct
model_lora_clone = model_lora.clone(1)
model_lora_clone1 = model_lora1.clone(1)
assert manager.add_adapter(model_lora) assert manager.add_adapter(model_lora)
assert manager.add_adapter(model_lora1) assert manager.add_adapter(model_lora1)
assert model_lora.get_lora("gate_proj") is None
assert model_lora.get_lora("up_proj") is None
assert model_lora1.get_lora("up_proj") is None
packed_lora = model_lora.get_lora("gate_up_proj") packed_lora = model_lora.get_lora("gate_up_proj")
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
torch.testing.assert_close(packed_lora.lora_a[0], torch.testing.assert_close(packed_lora.lora_a[0],
model_lora.get_lora("gate_proj").lora_a) model_lora_clone.get_lora("gate_proj").lora_a)
torch.testing.assert_close(packed_lora.lora_b[0], torch.testing.assert_close(packed_lora.lora_b[0],
model_lora.get_lora("gate_proj").lora_b) model_lora_clone.get_lora("gate_proj").lora_b)
torch.testing.assert_close(packed_lora.lora_a[1], torch.testing.assert_close(packed_lora.lora_a[1],
model_lora.get_lora("up_proj").lora_a) model_lora_clone.get_lora("up_proj").lora_a)
torch.testing.assert_close(packed_lora.lora_b[1], torch.testing.assert_close(packed_lora.lora_b[1],
model_lora.get_lora("up_proj").lora_b) model_lora_clone.get_lora("up_proj").lora_b)
packed_lora1 = model_lora1.get_lora("gate_up_proj") packed_lora1 = model_lora1.get_lora("gate_up_proj")
assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
...@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): ...@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_a[0] is None
assert packed_lora1.lora_b[0] is None assert packed_lora1.lora_b[0] is None
torch.testing.assert_close(packed_lora1.lora_a[1], torch.testing.assert_close(packed_lora1.lora_a[1],
model_lora1.get_lora("up_proj").lora_a) model_lora_clone1.get_lora("up_proj").lora_a)
torch.testing.assert_close(packed_lora1.lora_b[1], torch.testing.assert_close(packed_lora1.lora_b[1],
model_lora1.get_lora("up_proj").lora_b) model_lora_clone1.get_lora("up_proj").lora_b)
...@@ -5,7 +5,8 @@ import math ...@@ -5,7 +5,8 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
Union)
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager): ...@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items(): for module_name, new_module_names in self.packed_modules.items():
replacement_loras: List[Optional[LoRALayerWeights]] = [] replacement_loras: List[Optional[LoRALayerWeights]] = []
replaced_module: Set[str] = set()
has_replacement = False has_replacement = False
for r in new_module_names: for r in new_module_names:
lora = lora_model.get_lora(r) lora = lora_model.get_lora(r)
replacement_loras.append(lora) replacement_loras.append(lora)
if lora: if lora:
has_replacement = True has_replacement = True
replaced_module.add(r)
if not has_replacement: if not has_replacement:
continue continue
for i in range(len(replacement_loras)): for i in range(len(replacement_loras)):
...@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager): ...@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
replacement_loras[i] = None replacement_loras[i] = None
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras) replacement_loras)
# Remove the modules that have been replaced.
for module in replaced_module:
lora_model.loras.pop(module, None)
def deactivate_adapter(self, adapter_id: int) -> bool: def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters, return deactivate_adapter(adapter_id, self._active_adapters,
......
...@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
dtype=torch.long, dtype=torch.long,
device=device) device=device)
# 5 is the number of indicies tensors. # 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices # embeddings_indices,long_lora_indices
self.indices_len: List[Optional[int]] = [None] * 5 self.indices_len: List[Optional[int]] = [None] * 5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment