Commit f11a2e2a authored by Hyunsung Lee's avatar Hyunsung Lee Committed by muyangli
Browse files

Update utils.py

parent f04b603c
......@@ -245,6 +245,9 @@ class FluxCachedTransformerBlocks(nn.Module):
self.return_hidden_states_only = return_hidden_states_only
self.verbose = verbose
def update_residual_diff_threshold(self, residual_diff_threshold=0.12):
self.residual_diff_threshold = residual_diff_threshold
def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment