Unverified commit 3c8b67b3 authored by Colle, committed by GitHub

Flux: pass joint_attention_kwargs when using gradient_checkpointing (#11814)

Flux: pass joint_attention_kwargs when gradient checkpointing is enabled
parent 9feb9464
@@ -490,6 +490,7 @@ class FluxTransformer2DModel(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )
             else:
@@ -521,6 +522,7 @@ class FluxTransformer2DModel(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )
             else:
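The change is small but matters at runtime: the gradient-checkpointed branch must receive the same arguments as the regular branch, otherwise anything passed via joint_attention_kwargs is silently dropped whenever checkpointing is enabled during training. Below is a minimal, self-contained sketch of the pattern, not the diffusers source: ToyBlock and ToyTransformer are hypothetical stand-ins, and torch.utils.checkpoint.checkpoint stands in for whatever checkpointing wrapper the model actually uses.

```python
# Illustrative sketch only -- ToyBlock/ToyTransformer are hypothetical, not diffusers code.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb, joint_attention_kwargs=None):
        # Optional kwargs influence the block's behaviour (here, a simple scale factor).
        scale = 1.0 if joint_attention_kwargs is None else joint_attention_kwargs.get("scale", 1.0)
        return self.proj(hidden_states) * scale + temb


class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = ToyBlock()
        self.gradient_checkpointing = True

    def forward(self, hidden_states, temb, joint_attention_kwargs=None):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # The fix: forward joint_attention_kwargs here as well. Before the change,
            # the checkpointed call omitted it, so the kwargs were ignored in training.
            return checkpoint(
                self.block,
                hidden_states,
                temb,
                joint_attention_kwargs,
                use_reentrant=False,
            )
        return self.block(hidden_states, temb, joint_attention_kwargs=joint_attention_kwargs)


x = torch.randn(2, 8, requires_grad=True)
temb = torch.randn(2, 8)
out = ToyTransformer()(x, temb, joint_attention_kwargs={"scale": 0.5})
out.sum().backward()
```

With use_reentrant=False, non-tensor arguments such as the kwargs dict pass straight through the checkpoint wrapper, so the checkpointed call can mirror the plain call argument for argument.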