"dgl_sparse/vscode:/vscode.git/clone" did not exist on "5ffd2a022c3bc2b37a47d460d5f32d034f61d81e"
Unverified Commit a8e47978 authored by Sayak Paul, committed by GitHub

[lora] adapt new LoRA config injection method (#11999)

* use state dict when setting up LoRA.

* up

* up

* up

* comment

* up

* up
parent 50e18ee6
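In short, instead of diffusers deriving an `exclude_modules` list itself, the LoRA state dict is now handed to PEFT's `inject_adapter_in_model`, which decides where to inject adapter layers from the checkpoint keys. A minimal sketch of that pattern (not part of this commit; it assumes `peft>=0.17.0`, where `inject_adapter_in_model` accepts a `state_dict` argument, and `TinyModel` is a made-up stand-in for a diffusers denoiser):

# Minimal sketch of the new injection pattern; assumes peft>=0.17.0.
import torch
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict


class TinyModel(torch.nn.Module):  # made-up stand-in for a denoiser
    def __init__(self):
        super().__init__()
        self.proj_out = torch.nn.Linear(8, 8)
        self.proj_in = torch.nn.Linear(8, 8)


model = TinyModel()
rank = 4
# The LoRA checkpoint only covers proj_out.
lora_state_dict = {
    "proj_out.lora_A.weight": torch.zeros(rank, 8),
    "proj_out.lora_B.weight": torch.zeros(8, rank),
}
config = LoraConfig(r=rank, target_modules=["proj_out", "proj_in"])

# With state_dict passed, PEFT derives the injection targets from the checkpoint
# keys, so proj_in stays untouched and no exclude_modules bookkeeping is needed
# on the diffusers side.
inject_adapter_in_model(config, model, adapter_name="default", state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name="default")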
@@ -116,7 +116,7 @@ _deps = [
     "librosa",
     "numpy",
     "parameterized",
-    "peft>=0.15.0",
+    "peft>=0.17.0",
     "protobuf>=3.20.3,<4",
     "pytest",
     "pytest-timeout",
@@ -23,7 +23,7 @@ deps = {
     "librosa": "librosa",
     "numpy": "numpy",
     "parameterized": "parameterized",
-    "peft": "peft>=0.15.0",
+    "peft": "peft>=0.17.0",
    "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
@@ -320,7 +320,9 @@ class PeftAdapterMixin:
                 # it to None
                 incompatible_keys = None
             else:
-                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                inject_adapter_in_model(
+                    lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+                )
                 incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

             if self._prepare_lora_hotswap_kwargs is not None:
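For context (not part of the diff): nothing changes on the user side. Loading a LoRA still goes through `load_lora_weights()`, which should route into the `load_lora_adapter` code path touched above. A hedged usage sketch; the LoRA repo id below is hypothetical and used purely for illustration:

# Hedged usage sketch; "some-user/some-flux-lora" is a hypothetical repo id.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Internally this should end up in PeftAdapterMixin.load_lora_adapter, i.e. the
# inject_adapter_in_model(..., state_dict=state_dict, ...) call in the hunk above.
pipe.load_lora_weights("some-user/some-flux-lora")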
@@ -197,20 +197,6 @@ def get_peft_kwargs(
         "lora_bias": lora_bias,
     }

-    # Example: try load FusionX LoRA into Wan VACE
-    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
-    if exclude_modules:
-        if not is_peft_version(">=", "0.14.0"):
-            msg = """
-            It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
-            version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
-            peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
-            https://github.com/huggingface/diffusers/issues/new
-            """
-            logger.debug(msg)
-        else:
-            lora_config_kwargs.update({"exclude_modules": exclude_modules})
-
     return lora_config_kwargs
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):

     if warn_msg:
         logger.warning(warn_msg)
-
-
-def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
-    """
-    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
-    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
-    doesn't exist in `peft_state_dict`.
-    """
-    if model_state_dict is None:
-        return
-
-    all_modules = set()
-    string_to_replace = f"{adapter_name}." if adapter_name else ""
-
-    for name in model_state_dict.keys():
-        if string_to_replace:
-            name = name.replace(string_to_replace, "")
-        if "." in name:
-            module_name = name.rsplit(".", 1)[0]
-            all_modules.add(module_name)
-
-    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
-    exclude_modules = list(all_modules - target_modules_set)
-
-    return exclude_modules
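For reference, a toy run (not part of the diff, with made-up key names) of the logic the removed `_derive_exclude_modules` helper implemented: modules present in the model state dict but absent from the LoRA state dict were collected as `exclude_modules`.

# Toy illustration of the removed helper's logic on made-up keys.
model_state_dict = {
    "blocks.0.proj_out.weight": None,
    "blocks.0.proj_in.weight": None,
}
peft_state_dict = {"blocks.0.proj_out.lora_A.weight": None}

all_modules = {name.rsplit(".", 1)[0] for name in model_state_dict if "." in name}
target_modules = {name.split(".lora")[0] for name in peft_state_dict}
print(sorted(all_modules - target_modules))  # ['blocks.0.proj_in']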
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import inspect
 import os
 import re
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:
         return modules_to_save

-    def _get_exclude_modules(self, pipe):
-        from diffusers.utils.peft_utils import _derive_exclude_modules
-
-        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
-        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
-        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
-        pipe.unload_lora_weights()
-        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
-        exclude_modules = _derive_exclude_modules(
-            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
-        )
-        return exclude_modules
-
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

-    @require_peft_version_greater("0.13.2")
-    def test_lora_exclude_modules(self):
-        """
-        Test to check if `exclude_modules` works or not. It works in the following way:
-        we first create a pipeline and insert LoRA config into it. We then derive a `set`
-        of modules to exclude by investigating its denoiser state dict and denoiser LoRA
-        state dict.
-        We then create a new LoRA config to include the `exclude_modules` and perform tests.
-        """
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components).to(torch_device)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
-        # only supported for `denoiser` now
-        pipe_cp = copy.deepcopy(pipe)
-        pipe_cp, _ = self.add_adapters_to_pipeline(
-            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
-        pipe_cp.to("cpu")
-        del pipe_cp
-
-        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
-        pipe, _ = self.add_adapters_to_pipeline(
-            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
-
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(tmpdir)
-
-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertTrue(
-                not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
-                "LoRA should change outputs.",
-            )
-            self.assertTrue(
-                np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
-                "Lora outputs should match.",
-            )
-
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
@@ -20,7 +20,7 @@ import torch

 from diffusers import FluxTransformer2DModel
 from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
 from diffusers.models.embeddings import ImageProjection
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device

 from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
         expected_set = {"FluxTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

+    # The test exists for cases like
+    # https://github.com/huggingface/diffusers/issues/11874
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_exclude_modules(self):
+        from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+        lora_rank = 4
+        target_module = "single_transformer_blocks.0.proj_out"
+        adapter_name = "foo"
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        state_dict = model.state_dict()
+        target_mod_shape = state_dict[f"{target_module}.weight"].shape
+        lora_state_dict = {
+            f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+            f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+        }
+        # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
+        config = LoraConfig(
+            r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+        )
+        inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+        set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+
+        retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+        assert len(retrieved_lora_state_dict) == len(lora_state_dict)
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+

 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel