Unverified Commit 41e4779d authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] fix: lora loading when using with a device_mapped model. (#9449)



* fix: lora loading when using with a device_mapped model.

* better attibutung

* empty
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarMarc Sun <57196510+SunMarc@users.noreply.github.com>

* minors

* better error messages.

* fix-copies

* add: tests, docs.

* add hardware note.

* quality

* Update docs/source/en/training/distributed_inference.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* fixes

* skip properly.

* fixes

---------
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: default avatarMarc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent ff182ad6
......@@ -576,6 +576,15 @@ class UniDiffuserPipelineFastTests(
expected_text_prefix = '" This This'
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
def test_calling_mco_raises_error_device_mapped_components(self):
super().test_calling_mco_raises_error_device_mapped_components(safe_serialization=False)
def test_calling_to_raises_error_device_mapped_components(self):
super().test_calling_to_raises_error_device_mapped_components(safe_serialization=False)
def test_calling_sco_raises_error_device_mapped_components(self):
super().test_calling_sco_raises_error_device_mapped_components(safe_serialization=False)
@nightly
@require_torch_gpu
......
......@@ -237,3 +237,15 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
def test_callback_cfg(self):
pass
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment