Unverified Commit b316104d authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix Hunyuan I2V for `transformers>4.47.1` (#11293)

* update

* update
parent d3b2699a
......@@ -100,6 +100,50 @@ DEFAULT_PROMPT_TEMPLATE = {
}
def _expand_input_ids_with_image_tokens(
text_input_ids,
prompt_attention_mask,
max_sequence_length,
image_token_index,
image_emb_len,
image_emb_start,
image_emb_end,
pad_token_id,
):
special_image_token_mask = text_input_ids == image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
expanded_input_ids = torch.full(
(text_input_ids.shape[0], max_expanded_length),
pad_token_id,
dtype=text_input_ids.dtype,
device=text_input_ids.device,
)
expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
expanded_attention_mask = torch.zeros(
(text_input_ids.shape[0], max_expanded_length),
dtype=prompt_attention_mask.dtype,
device=prompt_attention_mask.device,
)
attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
return {
"input_ids": expanded_input_ids,
"attention_mask": expanded_attention_mask,
"position_ids": position_ids,
}
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
......@@ -251,6 +295,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt = [prompt_template["template"].format(p) for p in prompt]
crop_start = prompt_template.get("crop_start", None)
image_emb_len = prompt_template.get("image_emb_len", 576)
image_emb_start = prompt_template.get("image_emb_start", 5)
image_emb_end = prompt_template.get("image_emb_end", 581)
double_return_token_id = prompt_template.get("double_return_token_id", 271)
if crop_start is None:
prompt_template_input = self.tokenizer(
prompt_template["template"],
......@@ -280,19 +330,25 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
image_token_index = self.text_encoder.config.image_token_index
pad_token_id = self.text_encoder.config.pad_token_id
expanded_inputs = _expand_input_ids_with_image_tokens(
text_input_ids,
prompt_attention_mask,
max_sequence_length,
image_token_index,
image_emb_len,
image_emb_start,
image_emb_end,
pad_token_id,
)
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_attention_mask,
pixel_values=image_embeds,
**expanded_inputs,
pixel_value=image_embeds,
output_hidden_states=True,
).hidden_states[-(num_hidden_layers_to_skip + 1)]
prompt_embeds = prompt_embeds.to(dtype=dtype)
image_emb_len = prompt_template.get("image_emb_len", 576)
image_emb_start = prompt_template.get("image_emb_start", 5)
image_emb_end = prompt_template.get("image_emb_end", 581)
double_return_token_id = prompt_template.get("double_return_token_id", 271)
if crop_start is not None and crop_start > 0:
text_crop_start = crop_start - 1 + image_emb_len
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment