Project: fengzch-das/nunchaku
Commit: bf0813a6 ("Add SanaModel caching")
Authored: Mar 18, 2025 by Hyunsung Lee
Committed: Apr 01, 2025 by Zhekai Zhang
Parent: 65d7e47a

Showing 9 changed files with 313 additions and 51 deletions (+313 -51).
examples/int4-sana_1600m-cache.py                  +32   -0
nunchaku/caching/diffusers_adapters/__init__.py     +2   -0
nunchaku/caching/diffusers_adapters/flux.py         +1   -1
nunchaku/caching/diffusers_adapters/sana.py        +50   -0
nunchaku/caching/utils.py                         +125   -8
nunchaku/csrc/sana.h                               +19  -17
nunchaku/models/transformers/transformer_sana.py   +56   -0
src/SanaModel.cpp                                  +27  -24
src/SanaModel.h                                     +1   -1
examples/int4-sana_1600m-cache.py (new file, mode 100644)

import torch
from diffusers import SanaPipeline

from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    variant="bf16",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

apply_cache_on_pipe(pipe, residual_diff_threshold=0.25)

# WarmUp
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
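The example applies the cache once and treats the first generation as a warm-up. A rough way to see what the cache buys on later prompts is to time a generation after that warm-up. The helper below is a minimal sketch, not part of this commit; it reuses the pipe and prompt objects defined above and assumes a CUDA device:

    import time

    import torch

    def timed_generate(pipe, prompt, steps=20, seed=42):
        # Rough wall-clock timing of one generation; synchronize so GPU work is counted.
        torch.cuda.synchronize()
        start = time.perf_counter()
        pipe(
            prompt=prompt,
            height=1024,
            width=1024,
            guidance_scale=4.5,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
        )
        torch.cuda.synchronize()
        return time.perf_counter() - start

    # After the warm-up above, steps whose first-block residual changes little
    # reuse cached block outputs, so this run is typically faster than the same
    # pipeline without apply_cache_on_pipe.
    print(f"cached run: {timed_generate(pipe, prompt):.2f} s")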
nunchaku/caching/diffusers_adapters/__init__.py

@@ -7,6 +7,8 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
     pipe_cls_name = pipe.__class__.__name__
     if pipe_cls_name.startswith("Flux"):
         from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
+    elif pipe_cls_name.startswith("Sana"):
+        from .sana import apply_cache_on_pipe as apply_cache_on_pipe_fn
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
     return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
nunchaku/caching/diffusers_adapters/flux.py

@@ -13,7 +13,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            utils.CachedTransformerBlocks(
+            utils.FluxCachedTransformerBlocks(
                 transformer=transformer,
                 residual_diff_threshold=residual_diff_threshold,
                 return_hidden_states_first=False,
nunchaku/caching/diffusers_adapters/sana.py (new file, mode 100644)

import functools
import unittest

import torch
from diffusers import DiffusionPipeline, SanaTransformer2DModel

from ...caching import utils


def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.SanaCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call

        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
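apply_cache_on_pipe wraps the pipeline's __call__ so every generation runs inside a fresh cache context and, unless shallow_patch=True, also swaps the transformer's blocks for SanaCachedTransformerBlocks. The two pieces can also be combined by hand when you only want to patch the transformer. A minimal sketch, assuming the pipe object from the example above:

    from nunchaku.caching import utils
    from nunchaku.caching.diffusers_adapters.sana import apply_cache_on_transformer

    # Patch only the transformer, with a stricter threshold than the 0.12 default.
    apply_cache_on_transformer(pipe.transformer, residual_diff_threshold=0.1)

    # Each call still needs an active cache context so set_buffer/get_buffer
    # have somewhere to store the per-run residuals.
    with utils.cache_context(utils.create_cache_context()):
        image = pipe(prompt="a watercolor fox", num_inference_steps=20).images[0]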
nunchaku/caching/utils.py

@@ -3,7 +3,7 @@
 import contextlib
 import dataclasses
 from collections import defaultdict
-from typing import DefaultDict, Dict
+from typing import DefaultDict, Dict, Optional

 import torch
 from torch import nn

@@ -81,19 +81,20 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
 @torch.compiler.disable
 def apply_prev_hidden_states_residual(
-    hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+    hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     hidden_states_residual = get_buffer("hidden_states_residual")
     assert hidden_states_residual is not None, "hidden_states_residual must be set before"
     hidden_states = hidden_states_residual + hidden_states
+    hidden_states = hidden_states.contiguous()

-    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
-    assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
-    encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
-
-    hidden_states = hidden_states.contiguous()
-    encoder_hidden_states = encoder_hidden_states.contiguous()
+    if encoder_hidden_states is not None:
+        encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
+        assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
+        encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
+        encoder_hidden_states = encoder_hidden_states.contiguous()

     return hidden_states, encoder_hidden_states

@@ -108,8 +109,124 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=Fals
     )
     return can_use_cache


+class SanaCachedTransformerBlocks(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer=None,
+        residual_diff_threshold,
+        verbose: bool = False,
+    ):
+        super().__init__()
+        self.transformer = transformer
+        self.transformer_blocks = transformer.transformer_blocks
+        self.residual_diff_threshold = residual_diff_threshold
+        self.verbose = verbose
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask=None,
+        timestep=None,
+        post_patch_height=None,
+        post_patch_width=None,
+    ):
+        batch_size = hidden_states.shape[0]
+        if self.residual_diff_threshold <= 0.0 or batch_size > 2:
+            if batch_size > 2:
+                print("Batch size > 2 (for SANA CFG) currently not supported")
+            first_transformer_block = self.transformer_blocks[0]
+            hidden_states = first_transformer_block(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                timestep=timestep,
+                height=post_patch_height,
+                width=post_patch_width,
+                skip_first_layer=False,
+            )
+            return hidden_states
+
+        original_hidden_states = hidden_states
+        first_transformer_block = self.transformer_blocks[0]
+        hidden_states = first_transformer_block.forward_layer_at(
+            0,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            timestep=timestep,
+            height=post_patch_height,
+            width=post_patch_width,
+        )
+        first_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        can_use_cache = get_can_use_cache(
+            first_hidden_states_residual,
+            threshold=self.residual_diff_threshold,
+            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            del first_hidden_states_residual
+            if self.verbose:
+                print("Cache hit!!!")
+            hidden_states, _ = apply_prev_hidden_states_residual(hidden_states, None)
+        else:
+            if self.verbose:
+                print("Cache miss!!!")
+            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
+            del first_hidden_states_residual
+            hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                timestep=timestep,
+                post_patch_height=post_patch_height,
+                post_patch_width=post_patch_width,
+            )
+            set_buffer("hidden_states_residual", hidden_states_residual)
+        torch._dynamo.graph_break()
+
+        return hidden_states
+
+    def call_remaining_transformer_blocks(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask=None,
+        timestep=None,
+        post_patch_height=None,
+        post_patch_width=None,
+    ):
+        first_transformer_block = self.transformer_blocks[0]
+        original_hidden_states = hidden_states
+        hidden_states = first_transformer_block(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            timestep=timestep,
+            height=post_patch_height,
+            width=post_patch_width,
+            skip_first_layer=True,
+        )
+        hidden_states_residual = hidden_states - original_hidden_states
+        return hidden_states, hidden_states_residual
+
+
-class CachedTransformerBlocks(nn.Module):
+class FluxCachedTransformerBlocks(nn.Module):
     def __init__(
         self,
         *,
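The cache decision in SanaCachedTransformerBlocks.forward hinges on get_can_use_cache, which compares the current first-block residual against the first_hidden_states_residual saved at the last cache miss via are_two_tensors_similar; that helper's body is not shown in this diff. Purely as an illustration, assuming a relative-L1 style similarity metric rather than the library's exact one, the decision looks roughly like this (shapes are arbitrary placeholders):

    import torch

    def looks_similar(prev, curr, threshold):
        # Hypothetical stand-in for are_two_tensors_similar: mean absolute
        # difference normalized by the magnitude of the stored residual.
        return ((prev - curr).abs().mean() / prev.abs().mean()).item() < threshold

    prev_residual = torch.randn(2, 1024, 2240)                    # saved via set_buffer(...)
    curr_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)

    if looks_similar(prev_residual, curr_residual, threshold=0.12):
        # Cache hit: skip the remaining blocks and re-apply the stored
        # hidden_states_residual (apply_prev_hidden_states_residual).
        print("reuse cached residual")
    else:
        # Cache miss: run call_remaining_transformer_blocks and refresh the buffers.
        print("recompute and update buffers")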
nunchaku/csrc/sana.h

@@ -34,7 +34,8 @@ public:
                  int H,
                  int W,
                  bool pag,
-                 bool cfg)
+                 bool cfg,
+                 bool skip_first_layer = false)
     {
         checkModel();
         CUDADeviceContext ctx(deviceId);

@@ -54,7 +55,8 @@ public:
             from_torch(cu_seqlens_img),
             from_torch(cu_seqlens_txt),
             H, W,
-            pag, cfg
+            pag, cfg, skip_first_layer
         );
         torch::Tensor output = to_torch(result);
nunchaku/models/transformers/transformer_sana.py

@@ -30,6 +30,7 @@ class NunchakuSanaTransformerBlocks(nn.Module):
         timestep: Optional[torch.LongTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        skip_first_layer: Optional[bool] = False,
     ):
         batch_size = hidden_states.shape[0]

@@ -69,6 +70,61 @@ class NunchakuSanaTransformerBlocks(nn.Module):
                 width,
                 batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
                 True,  # TODO: find a way to detect if we are doing CFG
+                skip_first_layer,
             )
             .to(original_dtype)
             .to(original_device)
         )
+
+    def forward_layer_at(
+        self,
+        idx: int,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ):
+        batch_size = hidden_states.shape[0]
+        img_tokens = hidden_states.shape[1]
+        txt_tokens = encoder_hidden_states.shape[1]
+        original_dtype = hidden_states.dtype
+        original_device = hidden_states.device
+        assert encoder_attention_mask is not None
+        assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
+        mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
+        nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
+        cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
+        cu_seqlens_img = torch.arange(
+            0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
+        )
+        if height is None and width is None:
+            height = width = int(img_tokens**0.5)
+        elif height is None:
+            height = img_tokens // width
+        elif width is None:
+            width = img_tokens // height
+        assert height * width == img_tokens
+        return (
+            self.m.forward_layer(
+                idx,
+                hidden_states.to(self.dtype).to(self.device),
+                nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
+                timestep.to(self.dtype).to(self.device),
+                cu_seqlens_img.to(self.device),
+                cu_seqlens_txt.to(self.device),
+                height,
+                width,
+                batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
+                True,  # TODO: find a way to detect if we are doing CFG
+            )
+            .to(original_dtype)
+            .to(original_device)
+        )
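forward_layer_at packs the valid text tokens and builds cumulative sequence lengths (cu_seqlens) in the usual varlen-attention layout: a prefix sum of per-sample token counts, starting at 0. A toy sketch of the tensors it constructs, using made-up token counts:

    import torch
    import torch.nn.functional as F

    batch_size, img_tokens = 2, 1024
    # Text attention mask: sample 0 keeps 18 tokens, sample 1 keeps 25
    # (entries > -9000 count as valid, matching the check in forward_layer_at).
    mask = torch.full((batch_size, 300), -10000.0)
    mask[0, :18] = 0.0
    mask[1, :25] = 0.0

    cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
    cu_seqlens_img = torch.arange(0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32)

    print(cu_seqlens_txt)  # tensor([ 0, 18, 43], dtype=torch.int32)
    print(cu_seqlens_img)  # tensor([   0, 1024, 2048], dtype=torch.int32)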
src/SanaModel.cpp

+#include <iostream>
 #include "SanaModel.h"
 #include "kernels/zgemm/zgemm.h"
 #include "flash_api.h"

@@ -8,6 +10,7 @@
 using spdlog::fmt_lib::format;
 using namespace nunchaku;

 SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_pad(ceilDiv(dim, 128) * 128),

@@ -324,8 +327,8 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
     }
 }

-Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
-    for (int i = 0; i < config.num_layers; i++) {
+Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
+    for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
         auto &&block = transformer_blocks[i];
         hidden_states = block->forward(
             hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
src/SanaModel.h

@@ -89,7 +89,7 @@ struct SanaConfig {
 class SanaModel : public Module {
 public:
     SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
+    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);

 public:
     const SanaConfig config;