Unverified Commit f7b79452 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[modular] fix flux modular pipelines for t2i and i2i (#12272)

fix flux modular pipelines for t2i and i2i
parent 43459079
...@@ -454,6 +454,9 @@ class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks): ...@@ -454,6 +454,9 @@ class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
block_state = self.get_block_state(state) block_state = self.get_block_state(state)
block_state.device = components._execution_device block_state.device = components._execution_device
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
scheduler = components.scheduler scheduler = components.scheduler
transformer = components.transformer transformer = components.transformer
batch_size = block_state.batch_size * block_state.num_images_per_prompt batch_size = block_state.batch_size * block_state.num_images_per_prompt
...@@ -659,8 +662,6 @@ class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks): ...@@ -659,8 +662,6 @@ class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state) block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device block_state.device = components._execution_device
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this? block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
block_state.num_channels_latents = components.num_channels_latents block_state.num_channels_latents = components.num_channels_latents
......
...@@ -148,8 +148,8 @@ TEXT2IMAGE_BLOCKS = InsertableDict( ...@@ -148,8 +148,8 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
[ [
("text_encoder", FluxTextEncoderStep), ("text_encoder", FluxTextEncoderStep),
("input", FluxInputStep), ("input", FluxInputStep),
("set_timesteps", FluxSetTimestepsStep),
("prepare_latents", FluxPrepareLatentsStep), ("prepare_latents", FluxPrepareLatentsStep),
("set_timesteps", FluxSetTimestepsStep),
("denoise", FluxDenoiseStep), ("denoise", FluxDenoiseStep),
("decode", FluxDecodeStep), ("decode", FluxDecodeStep),
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment