renzhc / diffusers_dcu

Unverified commit b316104d, authored Apr 16, 2025 by Dhruv Nair, committed by GitHub on Apr 16, 2025
Parent: d3b2699a

Fix Hunyuan I2V for `transformers>4.47.1` (#11293)

* update
* update
Showing 1 changed file with 64 additions and 8 deletions:

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py (+64, -8)
@@ -100,6 +100,50 @@ DEFAULT_PROMPT_TEMPLATE = {
 }
 
 
+def _expand_input_ids_with_image_tokens(
+    text_input_ids,
+    prompt_attention_mask,
+    max_sequence_length,
+    image_token_index,
+    image_emb_len,
+    image_emb_start,
+    image_emb_end,
+    pad_token_id,
+):
+    special_image_token_mask = text_input_ids == image_token_index
+    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+    batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+    max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
+
+    new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
+    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+    expanded_input_ids = torch.full(
+        (text_input_ids.shape[0], max_expanded_length),
+        pad_token_id,
+        dtype=text_input_ids.dtype,
+        device=text_input_ids.device,
+    )
+    expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+    expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
+
+    expanded_attention_mask = torch.zeros(
+        (text_input_ids.shape[0], max_expanded_length),
+        dtype=prompt_attention_mask.dtype,
+        device=prompt_attention_mask.device,
+    )
+    attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
+    expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
+    expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
+    position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
+
+    return {
+        "input_ids": expanded_input_ids,
+        "attention_mask": expanded_attention_mask,
+        "position_ids": position_ids,
+    }
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
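The new helper appears to mirror the image-token expansion that older `transformers` releases performed inside `LlavaForConditionalGeneration`: every image placeholder in `input_ids` is widened to `image_emb_len` slots, the remaining text tokens are shifted right, and a matching `attention_mask` and `position_ids` are built for the longer sequence. Below is a minimal sketch of what it returns, using toy token ids (99 for the image token, 0 for padding) and a toy `image_emb_len` of 4 instead of the real default of 576; the expected outputs in the comments are worked by hand, not taken from the source:

```python
import torch

from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video_image2video import (
    _expand_input_ids_with_image_tokens,  # private helper added by this commit
)

# One prompt, max_sequence_length=6: [text, text, <image>, text, pad, pad]
text_input_ids = torch.tensor([[5, 6, 99, 7, 0, 0]])
prompt_attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])

out = _expand_input_ids_with_image_tokens(
    text_input_ids,
    prompt_attention_mask,
    max_sequence_length=6,
    image_token_index=99,  # toy id; the real one comes from text_encoder.config
    image_emb_len=4,       # toy width; the real default is 576
    image_emb_start=2,     # toy offset; the real default is 5
    image_emb_end=6,       # image_emb_start + image_emb_len; real default 581
    pad_token_id=0,
)

# The single image token now occupies image_emb_len positions, later text
# tokens are shifted right, and the original padding stays masked out
# (with position id filled as 1 by masked_fill_):
# out["input_ids"]      -> [[5, 6, 99, 99, 99, 99, 7, 0, 0]]
# out["attention_mask"] -> [[1, 1, 1, 1, 1, 1, 1, 0, 0]]
# out["position_ids"]   -> [[0, 1, 2, 3, 4, 5, 6, 1, 1]]
```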
@@ -251,6 +295,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
         prompt = [prompt_template["template"].format(p) for p in prompt]
 
         crop_start = prompt_template.get("crop_start", None)
+        image_emb_len = prompt_template.get("image_emb_len", 576)
+        image_emb_start = prompt_template.get("image_emb_start", 5)
+        image_emb_end = prompt_template.get("image_emb_end", 581)
+        double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
         if crop_start is None:
             prompt_template_input = self.tokenizer(
                 prompt_template["template"],
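These lookups make the image-embedding geometry overridable through the same `prompt_template` dict that already supplies `template` and `crop_start`. A sketch of the shape such a dict takes, with the fallback values spelled out; the `template` string and `crop_start` value here are placeholders for illustration, not the real defaults from `DEFAULT_PROMPT_TEMPLATE`:

```python
prompt_template = {
    "template": "<system prompt with an image placeholder>{}",  # placeholder string
    "crop_start": 103,              # placeholder value for illustration only
    "image_emb_len": 576,           # number of positions the image token expands to
    "image_emb_start": 5,           # first image-embedding position in the expanded ids
    "image_emb_end": 581,           # one past the last image position (5 + 576)
    "double_return_token_id": 271,  # presumably the id of the "\n\n" token, per its name
}
```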
@@ -280,19 +330,25 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
 
+        image_token_index = self.text_encoder.config.image_token_index
+        pad_token_id = self.text_encoder.config.pad_token_id
+        expanded_inputs = _expand_input_ids_with_image_tokens(
+            text_input_ids,
+            prompt_attention_mask,
+            max_sequence_length,
+            image_token_index,
+            image_emb_len,
+            image_emb_start,
+            image_emb_end,
+            pad_token_id,
+        )
         prompt_embeds = self.text_encoder(
-            input_ids=text_input_ids,
-            attention_mask=prompt_attention_mask,
-            pixel_values=image_embeds,
+            **expanded_inputs,
+            pixel_value=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
-        image_emb_len = prompt_template.get("image_emb_len", 576)
-        image_emb_start = prompt_template.get("image_emb_start", 5)
-        image_emb_end = prompt_template.get("image_emb_end", 581)
-        double_return_token_id = prompt_template.get("double_return_token_id", 271)
-
         if crop_start is not None and crop_start > 0:
             text_crop_start = crop_start - 1 + image_emb_len
             batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
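The net effect is that the Llava-style text encoder now receives `input_ids`, `attention_mask`, and `position_ids` already expanded to cover the 576 image-embedding positions, which `transformers` releases after 4.47.1 (per the commit title) appear to require: image-token expansion happens before the model call rather than inside it. A usage sketch of the repaired pipeline follows; the checkpoint id and input URL are assumptions, not taken from this commit:

```python
import torch

from diffusers import HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Assumed community checkpoint id; substitute your own path if it differs.
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-I2V", torch_dtype=torch.float16
)
pipe.vae.enable_tiling()  # keep VAE memory in check for video-sized latents
pipe.to("cuda")

image = load_image("https://example.com/first_frame.png")  # placeholder URL
video = pipe(image=image, prompt="A bear walks through a forest", num_frames=45).frames[0]
export_to_video(video, "output.mp4", fps=15)
```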