Commit b2df4e4e authored by muyangli's avatar muyangli
Browse files

Merge branch 'dev' of github.com:mit-han-lab/nunchaku into dev

parents 821db2a1 706a13f6
...@@ -257,7 +257,7 @@ class SanaCachedTransformerBlocks(nn.Module): ...@@ -257,7 +257,7 @@ class SanaCachedTransformerBlocks(nn.Module):
first_hidden_states_residual = hidden_states - original_hidden_states first_hidden_states_residual = hidden_states - original_hidden_states
del original_hidden_states del original_hidden_states
can_use_cache = get_can_use_cache( can_use_cache, _ = get_can_use_cache(
first_hidden_states_residual, first_hidden_states_residual,
threshold=self.residual_diff_threshold, threshold=self.residual_diff_threshold,
parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False), parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
...@@ -272,7 +272,7 @@ class SanaCachedTransformerBlocks(nn.Module): ...@@ -272,7 +272,7 @@ class SanaCachedTransformerBlocks(nn.Module):
else: else:
if self.verbose: if self.verbose:
print("Cache miss!!!") print("Cache miss!!!")
set_buffer("first_hidden_states_residual", first_hidden_states_residual) set_buffer("first_multi_hidden_states_residual", first_hidden_states_residual)
del first_hidden_states_residual del first_hidden_states_residual
hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks( hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks(
...@@ -284,7 +284,7 @@ class SanaCachedTransformerBlocks(nn.Module): ...@@ -284,7 +284,7 @@ class SanaCachedTransformerBlocks(nn.Module):
post_patch_height=post_patch_height, post_patch_height=post_patch_height,
post_patch_width=post_patch_width, post_patch_width=post_patch_width,
) )
set_buffer("hidden_states_residual", hidden_states_residual) set_buffer("multi_hidden_states_residual", hidden_states_residual)
torch._dynamo.graph_break() torch._dynamo.graph_break()
return hidden_states return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment