Unverified Commit 8d6dc2be authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Revert "[Flux] reduce explicit device transfers and typecasting in flux." (#9896)

Revert "[Flux] reduce explicit device transfers and typecasting in flux. (#9817)"

This reverts commit 5588725e.
parent d720b213
...@@ -371,7 +371,7 @@ class FluxPipeline( ...@@ -371,7 +371,7 @@ class FluxPipeline(
unscale_lora_layers(self.text_encoder_2, lora_scale) unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids return prompt_embeds, pooled_prompt_embeds, text_ids
...@@ -427,7 +427,7 @@ class FluxPipeline( ...@@ -427,7 +427,7 @@ class FluxPipeline(
@staticmethod @staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype): def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
...@@ -437,7 +437,7 @@ class FluxPipeline( ...@@ -437,7 +437,7 @@ class FluxPipeline(
latent_image_id_height * latent_image_id_width, latent_image_id_channels latent_image_id_height * latent_image_id_width, latent_image_id_channels
) )
return latent_image_ids return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod @staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width): def _pack_latents(latents, batch_size, num_channels_latents, height, width):
......
...@@ -452,7 +452,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -452,7 +452,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype): def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
...@@ -462,7 +462,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -462,7 +462,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
latent_image_id_height * latent_image_id_width, latent_image_id_channels latent_image_id_height * latent_image_id_width, latent_image_id_channels
) )
return latent_image_ids return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
......
...@@ -407,7 +407,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -407,7 +407,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
unscale_lora_layers(self.text_encoder_2, lora_scale) unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids return prompt_embeds, pooled_prompt_embeds, text_ids
...@@ -495,7 +495,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -495,7 +495,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype): def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
...@@ -505,7 +505,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -505,7 +505,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
latent_image_id_height * latent_image_id_width, latent_image_id_channels latent_image_id_height * latent_image_id_width, latent_image_id_channels
) )
return latent_image_ids return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
......
...@@ -417,7 +417,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -417,7 +417,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
unscale_lora_layers(self.text_encoder_2, lora_scale) unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids return prompt_embeds, pooled_prompt_embeds, text_ids
...@@ -522,7 +522,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -522,7 +522,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype): def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
...@@ -532,7 +532,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -532,7 +532,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
latent_image_id_height * latent_image_id_width, latent_image_id_channels latent_image_id_height * latent_image_id_width, latent_image_id_channels
) )
return latent_image_ids return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
......
...@@ -391,7 +391,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -391,7 +391,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
unscale_lora_layers(self.text_encoder_2, lora_scale) unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids return prompt_embeds, pooled_prompt_embeds, text_ids
...@@ -479,7 +479,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -479,7 +479,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype): def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
...@@ -489,7 +489,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -489,7 +489,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
latent_image_id_height * latent_image_id_width, latent_image_id_channels latent_image_id_height * latent_image_id_width, latent_image_id_channels
) )
return latent_image_ids return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
......
...@@ -395,7 +395,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -395,7 +395,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
unscale_lora_layers(self.text_encoder_2, lora_scale) unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids return prompt_embeds, pooled_prompt_embeds, text_ids
...@@ -500,7 +500,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -500,7 +500,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype): def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
...@@ -510,7 +510,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -510,7 +510,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
latent_image_id_height * latent_image_id_width, latent_image_id_channels latent_image_id_height * latent_image_id_width, latent_image_id_channels
) )
return latent_image_ids return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod @staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment