Unverified Commit 7f3e9b86 authored by Sayak Paul, committed by GitHub

make flux ready for mellon (#12419)



* make flux ready for mellon

* up

* Apply suggestions from code review
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

---------
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
parent ce90f9b2
@@ -252,11 +252,13 @@ class FluxInputStep(ModularPipelineBlocks):
             InputParam(
                 "prompt_embeds",
                 required=True,
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="Pre-generated text embeddings. Can be generated from text_encoder step.",
             ),
             InputParam(
                 "pooled_prompt_embeds",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
             ),
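
The `kwargs_type="denoiser_input_fields"` tag groups the conditioning tensors so that a downstream consumer (for example a node-based frontend such as Mellon) can fetch all denoiser inputs under one key instead of wiring each tensor by hand. Below is a minimal sketch of that grouping idea under a simplified param/state model; `collect_kwargs_type` and the literal values are hypothetical, not the diffusers implementation.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch


@dataclass
class Param:
    name: str
    kwargs_type: Optional[str] = None


def collect_kwargs_type(params: List[Param], state: Dict[str, Any], kwargs_type: str) -> Dict[str, Any]:
    """Gather every state value whose declaring param carries the given kwargs_type tag."""
    return {p.name: state[p.name] for p in params if p.kwargs_type == kwargs_type and p.name in state}


params = [
    Param("prompt_embeds", kwargs_type="denoiser_input_fields"),
    Param("pooled_prompt_embeds", kwargs_type="denoiser_input_fields"),
    Param("height"),  # untagged, so it stays out of the denoiser bundle
]
state = {
    "prompt_embeds": torch.zeros(1, 512, 4096),   # illustrative shapes only
    "pooled_prompt_embeds": torch.zeros(1, 768),
    "height": 1024,
}
denoiser_inputs = collect_kwargs_type(params, state, "denoiser_input_fields")
# -> {"prompt_embeds": ..., "pooled_prompt_embeds": ...}
```
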
@@ -279,11 +281,13 @@ class FluxInputStep(ModularPipelineBlocks):
             OutputParam(
                 "prompt_embeds",
                 type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
                 description="text embeddings used to guide the image generation",
             ),
             OutputParam(
                 "pooled_prompt_embeds",
                 type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
                 description="pooled text embeddings used to guide the image generation",
             ),
             # TODO: support negative embeddings?
@@ -181,6 +181,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
         return [
             InputParam("prompt"),
             InputParam("prompt_2"),
+            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
             InputParam("joint_attention_kwargs"),
         ]
@@ -189,16 +190,19 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
         return [
             OutputParam(
                 "prompt_embeds",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="text embeddings used to guide the image generation",
             ),
             OutputParam(
                 "pooled_prompt_embeds",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="pooled text embeddings used to guide the image generation",
             ),
             OutputParam(
                 "text_ids",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="ids from the text sequence for RoPE",
             ),
@@ -404,6 +408,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
             pooled_prompt_embeds=None,
             device=block_state.device,
             num_images_per_prompt=1,  # TODO: hardcoded for now.
+            max_sequence_length=block_state.max_sequence_length,
             lora_scale=block_state.text_encoder_lora_scale,
         )
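
The new `max_sequence_length` input (default 512, declared in the `FluxTextEncoderStep` inputs above) is forwarded into the prompt-encoding call in this hunk, so callers can trade prompt capacity for speed and memory on the T5 branch. Below is a hedged sketch of what the value typically controls, using a standard `transformers` tokenizer call; the checkpoint name is an assumption and this is not the block's own code.

```python
from transformers import AutoTokenizer

# Assumption: any T5 tokenizer illustrates the point; Flux pairs a T5 text encoder with CLIP.
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")

max_sequence_length = 256  # caller override; the block's declared default is 512
text_inputs = tokenizer(
    "a photo of a cat wearing a tiny hat",
    padding="max_length",
    max_length=max_sequence_length,
    truncation=True,
    return_tensors="pt",
)
print(text_inputs.input_ids.shape)  # torch.Size([1, 256]) -- capped at max_sequence_length
```
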
@@ -84,9 +84,9 @@ class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
 # before_denoise: all task (text2img, img2img)
 class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
-    block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
-    block_names = ["text2image", "img2img"]
-    block_trigger_inputs = [None, "image_latents"]
+    block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
+    block_names = ["img2img", "text2image"]
+    block_trigger_inputs = ["image_latents", None]

     @property
     def description(self):
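
Swapping the order matters because `AutoPipelineBlocks` dispatches on `block_trigger_inputs`: with the conditional `img2img` entry first and the `None` (catch-all) `text2image` entry last, providing `image_latents` selects the img2img branch, while plain text-to-image still falls through to the default. Below is a minimal first-match sketch of that dispatch rule; this is my reading of the intent, not the diffusers implementation.

```python
from typing import Any, Dict, Optional, Sequence


def select_block(
    block_names: Sequence[str],
    block_trigger_inputs: Sequence[Optional[str]],
    provided: Dict[str, Any],
) -> str:
    """First match wins; a None trigger always matches, so it acts as the fallback."""
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None or provided.get(trigger) is not None:
            return name
    raise ValueError("no block matched the provided inputs")


block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]

print(select_block(block_names, block_trigger_inputs, {"image_latents": object()}))  # img2img
print(select_block(block_names, block_trigger_inputs, {}))                           # text2image
```
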
@@ -124,16 +124,32 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
         return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"


+class FluxCoreDenoiseStep(SequentialPipelineBlocks):
+    block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
+    block_names = ["input", "before_denoise", "denoise"]
+
+    @property
+    def description(self):
+        return (
+            "Core step that performs the denoising process. \n"
+            + " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n"
+            + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+            + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+            + "This step support text-to-image and image-to-image tasks for Flux:\n"
+            + " - for image-to-image generation, you need to provide `image_latents`\n"
+            + " - for text-to-image generation, all you need to provide is prompt embeddings"
+        )
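
`FluxCoreDenoiseStep` folds input standardization, before-denoise preparation, and the denoise loop into one sequential unit, which is the single "denoise" node a workflow tool like Mellon can expose. Below is a toy sketch of the sequential-composition idea only, with hypothetical step functions rather than diffusers code.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]


def input_step(state: State) -> State:
    # standardize/validate inputs (batch size, dtype, ...)
    state.setdefault("batch_size", 1)
    return state


def before_denoise_step(state: State) -> State:
    # prepare latents and scheduler timesteps
    state["timesteps"] = [3, 2, 1, 0]
    return state


def denoise_step(state: State) -> State:
    # iterative denoising loop, stubbed out here
    for _t in state["timesteps"]:
        pass
    state["latents"] = "denoised-latents"
    return state


def run_sequential(steps: List[Callable[[State], State]], state: State) -> State:
    for step in steps:
        state = step(state)
    return state


final_state = run_sequential([input_step, before_denoise_step, denoise_step], {"prompt_embeds": "..."})
print(final_state["latents"])  # denoised-latents
```
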
 # text2image
 class FluxAutoBlocks(SequentialPipelineBlocks):
     block_classes = [
         FluxTextEncoderStep,
         FluxAutoVaeEncoderStep,
-        FluxAutoBeforeDenoiseStep,
-        FluxAutoDenoiseStep,
+        FluxCoreDenoiseStep,
         FluxAutoDecodeStep,
     ]
-    block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
+    block_names = ["text_encoder", "image_encoder", "denoise", "decode"]

     @property
     def description(self):
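
With the core step in place, `FluxAutoBlocks` now reads text_encoder → image_encoder → denoise → decode, and the last block name changes from "decoder" to "decode". A small inspection sketch, grounded in the attributes visible in this diff (`block_names`, `block_classes`, `description`); the import path and the no-argument instantiation are assumptions.

```python
# Import path is an assumption; the attributes used below appear in this diff.
from diffusers.modular_pipelines.flux import FluxAutoBlocks

blocks = FluxAutoBlocks()
print(blocks.block_names)   # ["text_encoder", "image_encoder", "denoise", "decode"]
print(blocks.description)   # human-readable summary of the composed steps
for name, cls in zip(blocks.block_names, blocks.block_classes):
    print(f"{name}: {cls.__name__}")
```
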
@@ -171,8 +187,7 @@ AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", FluxTextEncoderStep),
         ("image_encoder", FluxAutoVaeEncoderStep),
-        ("before_denoise", FluxAutoBeforeDenoiseStep),
-        ("denoise", FluxAutoDenoiseStep),
+        ("denoise", FluxCoreDenoiseStep),
         ("decode", FluxAutoDecodeStep),
     ]
 )
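
`AUTO_BLOCKS` is the ordered preset a tool can start from and recompose, which keeps the Mellon-facing surface to four named steps. Below is a customization sketch; `from_blocks_dict`, the import paths, and swapping in `FluxDecodeStep` directly are assumptions to verify against the modular-pipelines docs.

```python
# `from_blocks_dict` and the import paths below are assumptions, not confirmed by this diff.
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.flux import AUTO_BLOCKS, FluxDecodeStep

my_blocks = AUTO_BLOCKS.copy()
my_blocks["decode"] = FluxDecodeStep  # e.g. pin the plain decode step instead of the auto variant

pipeline_blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
print(pipeline_blocks)  # composed steps: text_encoder, image_encoder, denoise, decode
```
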