"src/diffusers/pipelines/flux/pipeline_flux.py" did not exist on "9b8c8605d14b4543db8e0169d4aac97ad95a2a27"
Commit 3996d421 authored by sandy's avatar sandy Committed by GitHub
Browse files

Automatically decide whether to do cfg based on guide_scale

parent 11303152
......@@ -109,6 +109,9 @@ class WanModel:
self.post_weight.to_cuda()
self.transformer_weights.to_cuda()
def do_classifier_free_guidance(self) -> bool:
return self.config.sample_guide_scale > 1
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
......@@ -125,7 +128,7 @@ class WanModel:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.config["enable_cfg"]:
if self.do_classifier_free_guidance():
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment