Unverified commit edb8c1bc, authored by Sayak Paul and committed by GitHub

[Flux] Improve true cfg condition (#10539)

* improve flux true cfg condition

* add test
parent 0785dba4
@@ -790,7 +790,10 @@ class FluxPipeline(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
+        has_neg_prompt = negative_prompt is not None or (
+            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+        )
+        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
         (
             prompt_embeds,
             pooled_prompt_embeds,
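The hunk above changes when FluxPipeline runs true classifier-free guidance: before this commit, do_true_cfg required a raw negative_prompt string, so callers who passed only pre-computed negative embeddings never triggered true CFG even with true_cfg_scale > 1. The following minimal sketch illustrates the new behavior; the checkpoint name, prompts, device, and the encode_prompt call pattern are illustrative assumptions and are not part of this diff.

# Illustrative sketch (not part of the diff): enabling true CFG in FluxPipeline
# with pre-computed negative embeddings instead of a raw negative_prompt string.
import torch
from diffusers import FluxPipeline

# Checkpoint chosen for illustration only.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Encode the negative prompt once so the embeddings can be reused across calls.
negative_prompt_embeds, negative_pooled_prompt_embeds, _ = pipe.encode_prompt(
    prompt="blurry, low quality", prompt_2=None, device="cuda"
)

# With this commit, true_cfg_scale > 1 plus both negative embedding tensors is
# enough to enable true CFG; a negative_prompt string is no longer required.
image = pipe(
    prompt="a photo of a cat",
    true_cfg_scale=2.0,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]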
@@ -209,6 +209,17 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte
         output_height, output_width, _ = image.shape
         assert (output_height, output_width) == (expected_height, expected_width)
 
+    def test_flux_true_cfg(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs.pop("generator")
+
+        no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+
+        inputs["negative_prompt"] = "bad quality"
+        inputs["true_cfg_scale"] = 2.0
+        true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+        assert not np.allclose(no_true_cfg_out, true_cfg_out)
 
 @nightly
 @require_big_gpu_with_torch_cuda