Unverified Commit 3c8b67b3 authored by Colle's avatar Colle Committed by GitHub
Browse files

Flux: pass joint_attention_kwargs when using gradient_checkpointing (#11814)

Flux: pass joint_attention_kwargs when gradient_checkpointing
parent 9feb9464
......@@ -490,6 +490,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
......@@ -521,6 +522,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment