Unverified Commit d13b0d63 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Flux] add lora integration tests. (#9353)

* add lora integration tests.

* internal note

* add a skip marker.
parent 5d476f57
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import os import os
import sys import sys
import tempfile import tempfile
...@@ -23,7 +24,14 @@ import torch ...@@ -23,7 +24,14 @@ import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device from diffusers.utils.testing_utils import (
floats_tensor,
is_peft_available,
require_peft_backend,
require_torch_gpu,
slow,
torch_device,
)
if is_peft_available(): if is_peft_available():
...@@ -145,3 +153,89 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -145,3 +153,89 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"Loading from saved checkpoints should give same results.", "Loading from saved checkpoints should give same results.",
) )
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
@slow
@require_torch_gpu
@require_peft_backend
@unittest.skip("We cannot run inference on this model with the current CI hardware")
# TODO (DN6, sayakpaul): move these tests to a beefier GPU
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace."""
num_inference_steps = 10
seed = 0
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()
prompt = "jon snow eating pizza with ketchup"
out = self.pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.0,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
def test_flux_kohya(self):
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()
prompt = "The cat with a brain slug earring"
out = self.pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
out = self.pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment