Unverified Commit 05679329 authored by YiYi Xu, committed by GitHub

[Hi Dream] follow-up (#11296)

* add
parent 29d2afbf
@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,108 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             x = torch.cat(x_arr, dim=0)
         return x
 
-    def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.config.patch_size * self.config.patch_size
-        if isinstance(x, torch.Tensor):
-            B, C = x.shape[0], x.shape[1]
-            device = x.device
-            dtype = x.dtype
-        else:
-            B, C = len(x), x[0].shape[0]
-            device = x[0].device
-            dtype = x[0].dtype
-        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
-
-        if img_sizes is not None:
-            for i, img_size in enumerate(img_sizes):
-                x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            B, C, S, _ = x.shape
-            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
-        elif isinstance(x, torch.Tensor):
-            B, C, Hp1, Wp2 = x.shape
-            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
-            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
-            x = x.permute(0, 2, 4, 3, 5, 1)
-            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
-            img_sizes = [[pH, pW]] * B
-            x_masks = None
-        else:
-            raise NotImplementedError
-        return x, x_masks, img_sizes
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
+        else:
+            hidden_states_masks = None
+
+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )
+        else:
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )
+
+        return hidden_states, hidden_states_masks, img_sizes, img_ids
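For intuition, here is a minimal, self-contained sketch of the square-latent branch above: the reshape/permute pair that turns a (B, C, H, W) latent into (B, num_patches, patch_size * patch_size * C) tokens. The standalone demo_patchify helper and the concrete sizes are illustrative, not part of this diff.

import torch

def demo_patchify(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, pH * pW, p * p * C), mirroring the square-latent branch
    b, c, h, w = latents.shape
    ph, pw = h // patch_size, w // patch_size
    x = latents.reshape(b, c, ph, patch_size, pw, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # bring the two patch axes together, channels last
    return x.reshape(b, ph * pw, patch_size * patch_size * c)

latents = torch.randn(2, 16, 64, 64)
tokens = demo_patchify(latents, patch_size=2)
print(tokens.shape)  # torch.Size([2, 1024, 64]) -> 32 * 32 patches, 2 * 2 * 16 features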
     def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
-        img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        **kwargs,
     ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.34.0", deprecation_message)
+
+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
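Usage sketch for the shim above; the surrounding tensors (latent_model_input, timestep, the two embedding streams, pooled_prompt_embeds) are assumed to exist and are illustrative:

# Old-style call: one tuple, still accepted but emits a deprecation warning.
out = transformer(
    hidden_states=latent_model_input,
    timesteps=timestep,
    encoder_hidden_states=(prompt_embeds_t5, prompt_embeds_llama3),  # deprecated
    pooled_embeds=pooled_prompt_embeds,
    return_dict=False,
)[0]

# New-style call: one keyword argument per text stream.
out = transformer(
    hidden_states=latent_model_input,
    timesteps=timestep,
    encoder_hidden_states_t5=prompt_embeds_t5,
    encoder_hidden_states_llama3=prompt_embeds_llama3,
    pooled_embeds=pooled_prompt_embeds,
    return_dict=False,
)[0]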
@@ -745,42 +807,19 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype
 
-        if hidden_states.shape[-2] != hidden_states.shape[-1]:
-            B, C, H, W = hidden_states.shape
-            patch_size = self.config.patch_size
-            pH, pW = H // patch_size, W // patch_size
-            out = torch.zeros(
-                (B, C, self.max_seq, patch_size * patch_size),
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-            )
-            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
-            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
-            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
-            out[:, :, 0 : pH * pW] = hidden_states
-            hidden_states = out
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)
 
         # 0. time
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
 
-        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if hidden_states_masks is None:
-            pH, pW = img_sizes[0]
-            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = (
-                img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
-                .unsqueeze(0)
-                .repeat(batch_size, 1, 1)
-            )
-        hidden_states = self.x_embedder(hidden_states)
-
-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
@@ -789,9 +828,9 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
             new_encoder_hidden_states.append(enc_hidden_state)
         encoder_hidden_states = new_encoder_hidden_states
-        T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-        T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-        encoder_hidden_states.append(T5_encoder_hidden_states)
+        encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+        encoder_hidden_states.append(encoder_hidden_states_t5)
         txt_ids = torch.zeros(
             batch_size,
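A shape-only sketch of the Llama-3 stream consumed above (all sizes are illustrative, including the layer indices): the tensor is stacked layer-first, so indexing the leading axis selects the decoder layers named in config.llama_layers.

import torch

num_layers, batch, seq, dim = 32, 2, 128, 4096  # hypothetical sizes
encoder_hidden_states_llama3 = torch.randn(num_layers, batch, seq, dim)
llama_layers = [4, 12, 20, 28]  # hypothetical layer indices from the model config
per_layer = [encoder_hidden_states_llama3[k] for k in llama_layers]
print(per_layer[0].shape)  # torch.Size([2, 128, 4096])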
...
@@ -15,7 +15,7 @@ from transformers import (
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HiDreamImagePipelineOutput
@@ -38,9 +38,6 @@ EXAMPLE_DOC_STRING = """
 >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
 >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
 
->>> scheduler = UniPCMultistepScheduler(
-...     flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
-... )
 >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
 >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
@@ -52,7 +49,6 @@ EXAMPLE_DOC_STRING = """
 >>> pipe = HiDreamImagePipeline.from_pretrained(
 ...     "HiDream-ai/HiDream-I1-Full",
-...     scheduler=scheduler,
 ...     tokenizer_4=tokenizer_4,
 ...     text_encoder_4=text_encoder_4,
 ...     torch_dtype=torch.bfloat16,
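Put together, the trimmed docstring amounts to roughly the following (the prompt string, seed, and the output_hidden_states flag shown here are illustrative; the flag reflects that the pipeline consumes intermediate Llama layers):

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # default scheduler is used; no manual UniPC setup needed
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    "A photo of an astronaut riding a horse",
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]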
@@ -148,7 +144,7 @@ def retrieve_timesteps(
 class HiDreamImagePipeline(DiffusionPipeline):
     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
-    _callback_tensor_inputs = ["latents", "prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
 
     def __init__(
         self,
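With the renamed callback tensor inputs, a step-end callback can read or swap the per-encoder embeddings by their new names. A sketch, assuming a loaded pipe as above (the cutoff step is illustrative):

def disable_cfg_after_warmup(pipe, i, t, callback_kwargs):
    # After step 10, drop classifier-free guidance to save compute.
    if i == 10:
        pipe._guidance_scale = 1.0
    return callback_kwargs

image = pipe(
    "a lighthouse at dusk",
    callback_on_step_end=disable_cfg_after_warmup,
    callback_on_step_end_tensor_inputs=["latents", "prompt_embeds_t5", "prompt_embeds_llama3"],
).images[0]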
...@@ -309,10 +305,10 @@ class HiDreamImagePipeline(DiffusionPipeline): ...@@ -309,10 +305,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
def encode_prompt( def encode_prompt(
self, self,
prompt: Union[str, List[str]], prompt: Optional[Union[str, List[str]]] = None,
prompt_2: Union[str, List[str]], prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Union[str, List[str]], prompt_3: Optional[Union[str, List[str]]] = None,
prompt_4: Union[str, List[str]], prompt_4: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
...@@ -321,8 +317,10 @@ class HiDreamImagePipeline(DiffusionPipeline): ...@@ -321,8 +317,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None,
negative_prompt_4: Optional[Union[str, List[str]]] = None, negative_prompt_4: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None, prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128, max_sequence_length: int = 128,
@@ -332,120 +330,177 @@ class HiDreamImagePipeline(DiffusionPipeline):
         if prompt is not None:
             batch_size = len(prompt)
         else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
-
-        prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_3=prompt_3,
-            prompt_4=prompt_4,
-            device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            max_sequence_length=max_sequence_length,
-        )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
-            negative_prompt_3 = negative_prompt_3 or negative_prompt
-            negative_prompt_4 = negative_prompt_4 or negative_prompt
-
-            # normalize str to list
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_2 = (
-                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-            )
-            negative_prompt_3 = (
-                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
-            )
-            negative_prompt_4 = (
-                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
-            )
-
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_3=negative_prompt_3,
-                prompt_4=negative_prompt_4,
-                device=device,
-                dtype=dtype,
-                num_images_per_prompt=num_images_per_prompt,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                max_sequence_length=max_sequence_length,
-            )
-
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
-    def _encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        max_sequence_length: int = 128,
-    ):
-        device = device or self._execution_device
-
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
-
-        if pooled_prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
-            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
-                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
-            )
-            pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
-                self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
-            )
-            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
-        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
-        if prompt_embeds is None:
-            prompt_3 = prompt_3 or prompt
-            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
-            prompt_4 = prompt_4 or prompt
-            prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
-            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
-            llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
-
-        _, seq_len, _ = t5_prompt_embeds.shape
-        t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        _, _, seq_len, dim = llama3_prompt_embeds.shape
-        llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
-        llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
-        prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
-
-        return prompt_embeds, pooled_prompt_embeds
+            batch_size = pooled_prompt_embeds.shape[0]
+
+        device = device or self._execution_device
+
+        if pooled_prompt_embeds is None:
+            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
+            )
+
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if len(negative_prompt) > 1 and len(negative_prompt) != batch_size:
+                raise ValueError(f"negative_prompt must be of length 1 or {batch_size}")
+
+            negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
+            )
+
+            if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1:
+                negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1)
+
+        if pooled_prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            if len(prompt_2) > 1 and len(prompt_2) != batch_size:
+                raise ValueError(f"prompt_2 must be of length 1 or {batch_size}")
+
+            pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+                self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
+            )
+
+            if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
+                pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1)
+
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+            negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+
+            if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size:
+                raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}")
+
+            negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+                self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
+            )
+
+            if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
+                negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1)
+
+        if pooled_prompt_embeds is None:
+            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
+
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_pooled_prompt_embeds = torch.cat(
+                [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
+            )
+
+        if prompt_embeds_t5 is None:
+            prompt_3 = prompt_3 or prompt
+            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+            if len(prompt_3) > 1 and len(prompt_3) != batch_size:
+                raise ValueError(f"prompt_3 must be of length 1 or {batch_size}")
+
+            prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
+
+            if prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
+                prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_t5 is None:
+            negative_prompt_3 = negative_prompt_3 or negative_prompt
+            negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+
+            if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size:
+                raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}")
+
+            negative_prompt_embeds_t5 = self._get_t5_prompt_embeds(
+                negative_prompt_3, max_sequence_length, device, dtype
+            )
+
+            if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
+                negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
+
+        if prompt_embeds_llama3 is None:
+            prompt_4 = prompt_4 or prompt
+            prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
+
+            if len(prompt_4) > 1 and len(prompt_4) != batch_size:
+                raise ValueError(f"prompt_4 must be of length 1 or {batch_size}")
+
+            prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
+
+            if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
+                prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None:
+            negative_prompt_4 = negative_prompt_4 or negative_prompt
+            negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+
+            if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size:
+                raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}")
+
+            negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds(
+                negative_prompt_4, max_sequence_length, device, dtype
+            )
+
+            if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
+                negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+
+        # duplicate pooled_prompt_embeds for each generation per prompt
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+        # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt
+        bs_embed, seq_len, _ = prompt_embeds_t5.shape
+        if bs_embed == 1 and batch_size > 1:
+            prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
+        elif bs_embed > 1 and bs_embed != batch_size:
+            raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}")
+        prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt
+        _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape
+        if bs_embed == 1 and batch_size > 1:
+            prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+        elif bs_embed > 1 and bs_embed != batch_size:
+            raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}")
+        prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
+        prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+
+        if do_classifier_free_guidance:
+            # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt
+            bs_embed, seq_len = negative_pooled_prompt_embeds.shape
+            if bs_embed == 1 and batch_size > 1:
+                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1)
+            elif bs_embed > 1 and bs_embed != batch_size:
+                raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}")
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+            # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt
+            bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape
+            if bs_embed == 1 and batch_size > 1:
+                negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
+            elif bs_embed > 1 and bs_embed != batch_size:
+                raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}")
+            negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt
+            _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape
+            if bs_embed == 1 and batch_size > 1:
+                negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+            elif bs_embed > 1 and bs_embed != batch_size:
+                raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}")
+            negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
+            negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view(
+                -1, batch_size * num_images_per_prompt, seq_len, dim
+            )
+
+        return (
+            prompt_embeds_t5,
+            negative_prompt_embeds_t5,
+            prompt_embeds_llama3,
+            negative_prompt_embeds_llama3,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
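One payoff of the split return values: embeddings can be computed once and reused across generations. A sketch, assuming a loaded pipe (prompt text and seeds are illustrative):

(
    prompt_embeds_t5,
    negative_prompt_embeds_t5,
    prompt_embeds_llama3,
    negative_prompt_embeds_llama3,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt="a cat wearing a spacesuit", do_classifier_free_guidance=True)

for seed in (0, 1, 2):  # vary only the noise, reuse the text encodings
    image = pipe(
        prompt_embeds_t5=prompt_embeds_t5,
        prompt_embeds_llama3=prompt_embeds_llama3,
        negative_prompt_embeds_t5=negative_prompt_embeds_t5,
        negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]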
     def enable_vae_slicing(self):
         r"""
@@ -476,6 +531,115 @@ class HiDreamImagePipeline(DiffusionPipeline):
         """
         self.vae.disable_tiling()
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        prompt_3,
+        prompt_4,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        negative_prompt_3=None,
+        negative_prompt_4=None,
+        prompt_embeds_t5=None,
+        prompt_embeds_llama3=None,
+        negative_prompt_embeds_t5=None,
+        negative_prompt_embeds_llama3=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_3 is not None and prompt_embeds_t5 is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_4 is not None and prompt_embeds_llama3 is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
+            )
+        elif prompt is None and prompt_embeds_t5 is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined."
+            )
+        elif prompt is None and prompt_embeds_llama3 is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+            raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+        elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)):
+            raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}")
+
+        if negative_prompt is not None and negative_pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:"
+                f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:"
+                f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:"
+                f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:"
+                f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two."
+            )
+
+        if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None:
+            if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape:
+                raise ValueError(
+                    "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`"
+                    f" {negative_pooled_prompt_embeds.shape}."
+                )
+
+        if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None:
+            if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape:
+                raise ValueError(
+                    "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`"
+                    f" {negative_prompt_embeds_t5.shape}."
+                )
+
+        if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None:
+            if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape:
+                raise ValueError(
+                    "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`"
+                    f" {negative_prompt_embeds_llama3.shape}."
+                )
     def prepare_latents(
         self,
         batch_size,
@@ -542,8 +706,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+        prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -552,6 +718,7 @@ class HiDreamImagePipeline(DiffusionPipeline):
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 128,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -649,6 +816,22 @@ class HiDreamImagePipeline(DiffusionPipeline):
             [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is a list with the generated images.
         """
+        prompt_embeds = kwargs.get("prompt_embeds", None)
+        negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
+
+        if prompt_embeds is not None:
+            deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
+            deprecate("prompt_embeds", "0.34.0", deprecation_message)
+            prompt_embeds_t5 = prompt_embeds[0]
+            prompt_embeds_llama3 = prompt_embeds[1]
+
+        if negative_prompt_embeds is not None:
+            deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
+            deprecate("negative_prompt_embeds", "0.34.0", deprecation_message)
+            negative_prompt_embeds_t5 = negative_prompt_embeds[0]
+            negative_prompt_embeds_llama3 = negative_prompt_embeds[1]
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
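Migration sketch for the deprecated list arguments handled above (variable names are illustrative): callers that built prompt_embeds=[t5, llama3] now pass each stream under its own keyword.

# Before: both streams packed into one list (still works, warns via the shim).
image = pipe(
    prompt_embeds=[t5_embeds, llama3_embeds],
    negative_prompt_embeds=[neg_t5_embeds, neg_llama3_embeds],
    pooled_prompt_embeds=pooled,
    negative_pooled_prompt_embeds=neg_pooled,
).images[0]

# After: one keyword per stream.
image = pipe(
    prompt_embeds_t5=t5_embeds,
    prompt_embeds_llama3=llama3_embeds,
    negative_prompt_embeds_t5=neg_t5_embeds,
    negative_prompt_embeds_llama3=neg_llama3_embeds,
    pooled_prompt_embeds=pooled,
    negative_pooled_prompt_embeds=neg_pooled,
).images[0]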
@@ -658,6 +841,25 @@ class HiDreamImagePipeline(DiffusionPipeline):
         scale = math.sqrt(scale)
         width, height = int(width * scale // division * division), int(height * scale // division * division)
 
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            prompt_3,
+            prompt_4,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            negative_prompt_3=negative_prompt_3,
+            negative_prompt_4=negative_prompt_4,
+            prompt_embeds_t5=prompt_embeds_t5,
+            prompt_embeds_llama3=prompt_embeds_llama3,
+            negative_prompt_embeds_t5=negative_prompt_embeds_t5,
+            negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+        )
+
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
@@ -667,17 +869,18 @@ class HiDreamImagePipeline(DiffusionPipeline):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
-        elif prompt_embeds is not None:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
+        elif pooled_prompt_embeds is not None:
+            batch_size = pooled_prompt_embeds.shape[0]
+        else:
+            batch_size = 1
 
         device = self._execution_device
 
+        # 3. Encode prompt
         lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
         (
-            prompt_embeds,
-            negative_prompt_embeds,
+            prompt_embeds_t5,
+            negative_prompt_embeds_t5,
+            prompt_embeds_llama3,
+            negative_prompt_embeds_llama3,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
@@ -690,8 +893,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
             negative_prompt_3=negative_prompt_3,
             negative_prompt_4=negative_prompt_4,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_embeds_t5=prompt_embeds_t5,
+            prompt_embeds_llama3=prompt_embeds_llama3,
+            negative_prompt_embeds_t5=negative_prompt_embeds_t5,
+            negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
@@ -701,13 +906,8 @@ class HiDreamImagePipeline(DiffusionPipeline):
         )
 
         if self.do_classifier_free_guidance:
-            prompt_embeds_arr = []
-            for n, p in zip(negative_prompt_embeds, prompt_embeds):
-                if len(n.shape) == 3:
-                    prompt_embeds_arr.append(torch.cat([n, p], dim=0))
-                else:
-                    prompt_embeds_arr.append(torch.cat([n, p], dim=1))
-            prompt_embeds = prompt_embeds_arr
+            prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0)
+            prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 4. Prepare latent variables
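The two concatenation dims above differ because the T5 embeddings are batch-first, (batch, seq, dim), while the Llama-3 stack is layer-first, (num_layers, batch, seq, dim); the negative/positive halves therefore join on dim=0 and dim=1 respectively. A shape-only sketch with illustrative sizes:

import torch

neg_t5, pos_t5 = torch.randn(2, 128, 4096), torch.randn(2, 128, 4096)
neg_llama, pos_llama = torch.randn(32, 2, 128, 4096), torch.randn(32, 2, 128, 4096)

cfg_t5 = torch.cat([neg_t5, pos_t5], dim=0)           # (4, 128, 4096): batch axis is dim 0
cfg_llama = torch.cat([neg_llama, pos_llama], dim=1)  # (32, 4, 128, 4096): batch axis is dim 1
print(cfg_t5.shape, cfg_llama.shape)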
@@ -723,26 +923,6 @@ class HiDreamImagePipeline(DiffusionPipeline):
             latents,
         )
 
-        if latents.shape[-2] != latents.shape[-1]:
-            B, C, H, W = latents.shape
-            pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
-
-            img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
-            img_ids = torch.zeros(pH, pW, 3)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
-            img_ids = img_ids.reshape(pH * pW, -1)
-            img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
-            img_ids_pad[: pH * pW, :] = img_ids
-
-            img_sizes = img_sizes.unsqueeze(0).to(latents.device)
-            img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
-            if self.do_classifier_free_guidance:
-                img_sizes = img_sizes.repeat(2 * B, 1)
-                img_ids = img_ids.repeat(2 * B, 1, 1)
-        else:
-            img_sizes = img_ids = None
-
         # 5. Prepare timesteps
         mu = calculate_shift(self.transformer.max_seq)
         scheduler_kwargs = {"mu": mu}
@@ -774,10 +954,9 @@ class HiDreamImagePipeline(DiffusionPipeline):
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timesteps=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states_t5=prompt_embeds_t5,
+                    encoder_hidden_states_llama3=prompt_embeds_llama3,
                     pooled_embeds=pooled_prompt_embeds,
-                    img_sizes=img_sizes,
-                    img_ids=img_ids,
                     return_dict=False,
                 )[0]
                 noise_pred = -noise_pred
@@ -803,8 +982,9 @@ class HiDreamImagePipeline(DiffusionPipeline):
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5)
+                    prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3)
+                    pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
...
@@ -43,7 +43,7 @@ enable_full_determinism()
 class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = HiDreamImagePipeline
-    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
...