Unverified Commit 17128c42 authored by Sayak Paul, committed by GitHub

[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill (#10259)



* lora expansion with dummy zeros.

* updates

* fix working 🥳

* working.

* use torch.device meta for state dict expansion.

* tests
Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>

* fixes

* fixes

* switch to debug

* fix

* Apply suggestions from code review
Co-authored-by: Aryan <aryan@huggingface.co>

* fix stuff

* docs

---------
Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent dbc1d505
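In practice, this change means a LoRA trained against the base Flux.1-Dev transformer can be loaded into pipelines whose transformer expects more `x_embedder` input features, such as Flux Control or Flux Fill; the missing `lora_A` input columns are zero padded on the fly. A minimal sketch of the new behavior (the style-LoRA repository name below is purely illustrative):

```py
import torch
from diffusers import FluxControlPipeline

# Flux Control's transformer takes more x_embedder input features than Flux.1-Dev,
# but a regular Flux.1-Dev LoRA can now be loaded into it directly.
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16)

# Hypothetical LoRA trained on Flux.1-Dev; its lora_A weights are zero padded
# internally to match the expanded x_embedder input features.
pipe.load_lora_weights("some-user/flux-dev-style-lora", adapter_name="style")
```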
@@ -268,6 +268,43 @@ images = pipe(
images[0].save("flux-redux.png")
```
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-step inference. The example below shows how to do this with the Flux Control depth LoRA and the turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
```py
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch
control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = control_pipe(
prompt=prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```
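The same recipe applies to the Fill and Redux pipelines mentioned in the heading. Below is a sketch for Flux Fill with the same Hyper-SD turbo LoRA; it assumes the `black-forest-labs/FLUX.1-Fill-dev` checkpoint and the cup image/mask pair used in the Flux Fill examples, and you can substitute your own inpainting inputs:

```py
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download

pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
pipe.enable_model_cpu_offload()

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

output = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    num_inference_steps=8,
    guidance_scale=30.0,
    generator=torch.Generator().manual_seed(0),
).images[0]
output.save("flux-fill-turbo.png")
```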
## Running FP16 inference
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
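A minimal sketch of that recommendation, assuming the standard `FluxPipeline` components: keep both text encoders in FP32 and run the rest of the pipeline in FP16.

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16)
# Keep both text encoders in FP32 to avoid the activation clipping described above.
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)
pipe.enable_model_cpu_offload()

image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator().manual_seed(0),
).images[0]
image.save("flux-fp16.png")
```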
...
@@ -1863,6 +1863,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
"To get a comprehensive list of parameter names that were modified, enable debug logging." "To get a comprehensive list of parameter names that were modified, enable debug logging."
) )
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer=transformer, lora_state_dict=transformer_lora_state_dict
)
if len(transformer_lora_state_dict) > 0: if len(transformer_lora_state_dict) > 0:
self.load_lora_into_transformer( self.load_lora_into_transformer(
@@ -2309,16 +2312,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for name, module in transformer.named_modules():
    if isinstance(module, torch.nn.Linear):
        module_weight = module.weight.data
        module_bias = module.bias.data if module.bias is not None else None
        bias = module_bias is not None

        lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
        lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
        lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
        if lora_A_weight_name not in state_dict:
            continue

        in_features = state_dict[lora_A_weight_name].shape[1]
@@ -2329,56 +2333,105 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
    continue

module_out_features, module_in_features = module_weight.shape
debug_message = ""
if in_features > module_in_features:
    debug_message += (
        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
        f"checkpoint contains higher number of features than expected. The number of input_features will be "
        f"expanded from {module_in_features} to {in_features}"
    )
if out_features > module_out_features:
    debug_message += (
        ", and the number of output features will be "
        f"expanded from {module_out_features} to {out_features}."
    )
else:
    debug_message += "."
if debug_message:
    logger.debug(debug_message)

if out_features > module_out_features or in_features > module_in_features:
    has_param_with_shape_update = True
    parent_module_name, _, current_module_name = name.rpartition(".")
    parent_module = transformer.get_submodule(parent_module_name)

    with torch.device("meta"):
        expanded_module = torch.nn.Linear(
            in_features, out_features, bias=bias, dtype=module_weight.dtype
        )
    # Only weights are expanded and biases are not. This is because only the input dimensions
    # are changed while the output dimensions remain the same. The shape of the weight tensor
    # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
    # explains the reason why only weights are expanded.
    new_weight = torch.zeros_like(
        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
    )
    slices = tuple(slice(0, dim) for dim in module_weight.shape)
    new_weight[slices] = module_weight
    tmp_state_dict = {"weight": new_weight}
    if module_bias is not None:
        tmp_state_dict["bias"] = module_bias
    expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)

    setattr(parent_module, current_module_name, expanded_module)

    del tmp_state_dict

    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
        new_value = int(expanded_module.weight.data.shape[1])
        old_value = getattr(transformer.config, attribute_name)
        setattr(transformer.config, attribute_name, new_value)
        logger.info(
            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
        )

return has_param_with_shape_update
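The module expansion above can be illustrated in isolation: a replacement `nn.Linear` is allocated on the meta device, a zero-initialized weight of the new shape receives the old weight in its leading slice, and the materialized tensors are attached via `load_state_dict(..., assign=True)`. A standalone sketch with made-up sizes (not part of the diff):

```py
import torch

# Grow a Linear layer's in_features from 64 to 128 while preserving its weights;
# the extra input columns are zeros, so zero-padded inputs give identical outputs.
old_linear = torch.nn.Linear(64, 3072, bias=True)

with torch.device("meta"):
    expanded = torch.nn.Linear(128, 3072, bias=True, dtype=old_linear.weight.dtype)

new_weight = torch.zeros(3072, 128, dtype=old_linear.weight.dtype)
new_weight[:, :64] = old_linear.weight.data
expanded.load_state_dict({"weight": new_weight, "bias": old_linear.bias.data}, strict=True, assign=True)

x = torch.randn(2, 64)
x_padded = torch.cat([x, torch.zeros(2, 64)], dim=1)
assert torch.allclose(old_linear(x), expanded(x_padded))
```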
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
    expanded_module_names = set()
    transformer_state_dict = transformer.state_dict()
    prefix = f"{cls.transformer_name}."

    lora_module_names = [
        key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
    ]
    lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
    lora_module_names = sorted(set(lora_module_names))
    transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
    unexpected_modules = set(lora_module_names) - set(transformer_module_names)
    if unexpected_modules:
        logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")

    is_peft_loaded = getattr(transformer, "peft_config", None) is not None
    for k in lora_module_names:
        if k in unexpected_modules:
            continue

        base_param_name = (
            f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
        )
        base_weight_param = transformer_state_dict[base_param_name]
        lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

        if base_weight_param.shape[1] > lora_A_param.shape[1]:
            shape = (lora_A_param.shape[0], base_weight_param.shape[1])
            expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
            expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
            lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
            expanded_module_names.add(k)
        elif base_weight_param.shape[1] < lora_A_param.shape[1]:
            raise NotImplementedError(
                f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
            )

    if expanded_module_names:
        logger.info(
            f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
        )

    return lora_state_dict
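The zero padding done by `_maybe_expand_lora_state_dict` is harmless for the original channels: padding `lora_A` with zero input columns leaves the LoRA update unchanged wherever the extra inputs are zero. A small numeric sketch with illustrative shapes (not part of the diff):

```py
import torch

rank, in_features, expanded_in_features, out_features = 4, 64, 128, 3072
lora_A = torch.randn(rank, in_features)      # regular LoRA trained on 64 input features
lora_B = torch.randn(out_features, rank)

# Zero pad lora_A to match a transformer whose x_embedder now takes 128 input features.
expanded_A = torch.zeros(rank, expanded_in_features)
expanded_A[:, :in_features].copy_(lora_A)

x = torch.randn(1, in_features)
x_expanded = torch.cat([x, torch.zeros(1, expanded_in_features - in_features)], dim=1)

# The LoRA delta computed on the original channels is unchanged by the padding.
delta_original = x @ lora_A.t() @ lora_B.t()
delta_expanded = x_expanded @ expanded_A.t() @ lora_B.t()
assert torch.allclose(delta_original, delta_expanded, atol=1e-5)
```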
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
...
@@ -340,21 +340,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
# We should error out because lora input features is less than original. We only
# support expanding the module, not shrinking it
with self.assertRaises(NotImplementedError):
pipe.load_lora_weights(lora_state_dict, "adapter-1")
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
    components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -430,10 +415,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_lora_expanding_shape_with_normal_lora(self):
    # This test checks that a LoRA with expanded shapes (like a Control LoRA) can be followed by a
    # LoRA with regular shapes. Loading a regular-shape LoRA first and then an expanded one isn't
    # supported and is tested further below.
    components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

    # Change the transformer config to mimic a real use case.
@@ -478,27 +463,18 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
} }
with CaptureLogger(logger) as cap_logger:
    pipe.load_lora_weights(lora_state_dict, "adapter-2")

self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])

lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
@@ -521,14 +497,11 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
} }
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
lora_state_dict = {
    "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
@@ -546,6 +519,98 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"adapter-2", "adapter-2",
) )
def test_fuse_expanded_lora_with_regular_lora(self):
# This test checks that an expanded-shape LoRA (like a Control LoRA) and a regular-shape LoRA can be
# loaded together, and that fusing them produces the same output as running with both adapters active.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
components["transformer"] = transformer
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
def test_load_regular_lora(self):
# This test checks that a regular LoRA (e.g., one trained on Flux.1-Dev) can be loaded into a
# transformer with more input channels than Flux.1-Dev, such as the Flux Fill and Flux Control
# transformers.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
@unittest.skip("Not supported in Flux.") @unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
...