fengzch-das / nunchaku / Commits / bf0813a6

Commit bf0813a6, authored Mar 18, 2025 by Hyunsung Lee; committed by Zhekai Zhang, Apr 01, 2025.
Parent: 65d7e47a

Add SanaModel caching

Showing 9 changed files with 313 additions and 51 deletions (+313 / -51):

    examples/int4-sana_1600m-cache.py                  +32   -0
    nunchaku/caching/diffusers_adapters/__init__.py     +2   -0
    nunchaku/caching/diffusers_adapters/flux.py         +1   -1
    nunchaku/caching/diffusers_adapters/sana.py        +50   -0
    nunchaku/caching/utils.py                         +125   -8
    nunchaku/csrc/sana.h                               +19  -17
    nunchaku/models/transformers/transformer_sana.py   +56   -0
    src/SanaModel.cpp                                  +27  -24
    src/SanaModel.h                                     +1   -1
examples/int4-sana_1600m-cache.py  (new file, +32 -0)

import torch
from diffusers import SanaPipeline

from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    variant="bf16",
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

apply_cache_on_pipe(pipe, residual_diff_threshold=0.25)

# WarmUp
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
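
The example enables the new first-block cache on a Sana pipeline: at every denoising step the patched transformer runs only its first block, compares that block's output residual with the one from the previous step, and reuses the cached residual for the remaining blocks whenever the two are close (see nunchaku/caching/utils.py below). residual_diff_threshold=0.25 is the similarity cutoff; larger values cache more aggressively. The snippet below is a minimal benchmarking sketch, not part of this commit; the helper name benchmark is made up, and it assumes the pipe and prompt objects from the example above.

import time

import torch


def benchmark(pipe, prompt, runs=3):
    # Hypothetical helper: average wall-clock latency over a few runs,
    # with a fixed seed so cached and uncached runs follow the same trajectory.
    timings = []
    for _ in range(runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        pipe(
            prompt=prompt,
            height=1024,
            width=1024,
            guidance_scale=4.5,
            num_inference_steps=20,
            generator=torch.Generator().manual_seed(42),
        )
        torch.cuda.synchronize()
        timings.append(time.perf_counter() - start)
    return sum(timings) / len(timings)


# Usage sketch: benchmark once before and once after calling
# apply_cache_on_pipe(pipe, residual_diff_threshold=0.25), then compare the averages.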
nunchaku/caching/diffusers_adapters/__init__.py

@@ -7,6 +7,8 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    pipe_cls_name = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
+    elif pipe_cls_name.startswith("Sana"):
+        from .sana import apply_cache_on_pipe as apply_cache_on_pipe_fn
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
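
With these two lines the shared entry point dispatches on the pipeline class name, so the same call now covers both Flux and Sana pipelines, and any other class still raises ValueError. A short usage sketch (the flux_pipe and sana_pipe objects are assumed to have been built elsewhere):

from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(flux_pipe, residual_diff_threshold=0.12)  # routed to the .flux adapter
apply_cache_on_pipe(sana_pipe, residual_diff_threshold=0.25)  # routed to the new .sana adapter

try:
    apply_cache_on_pipe(object())  # class name matches neither "Flux" nor "Sana"
except ValueError as err:
    print(err)  # Unknown pipeline class name: object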
nunchaku/caching/diffusers_adapters/flux.py

@@ -13,7 +13,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
    cached_transformer_blocks = torch.nn.ModuleList(
        [
-            utils.CachedTransformerBlocks(
+            utils.FluxCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
                return_hidden_states_first=False,
nunchaku/caching/diffusers_adapters/sana.py  (new file, +50 -0)

import functools
import unittest

import torch
from diffusers import DiffusionPipeline, SanaTransformer2DModel

from ...caching import utils


def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.SanaCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
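
apply_cache_on_pipe patches the pipeline's __call__ so that each generation runs inside a fresh cache context, while apply_cache_on_transformer swaps the transformer_blocks ModuleList for a single SanaCachedTransformerBlocks wrapper. The helpers create_cache_context, cache_context, set_buffer, and get_buffer live in nunchaku/caching/utils.py and their bodies are not part of the hunks shown in this commit, so the following is only a plausible sketch of that buffer-store pattern (class and variable names here are assumptions), not the repository's actual code:

import contextlib
import dataclasses
from typing import Dict

import torch


@dataclasses.dataclass
class _CacheContext:
    # One tensor store per pipeline invocation (sketch only).
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)


_current_context = None


def create_cache_context():
    return _CacheContext()


@contextlib.contextmanager
def cache_context(ctx):
    # Install ctx for the duration of one pipeline call, then restore the previous one.
    global _current_context
    previous, _current_context = _current_context, ctx
    try:
        yield
    finally:
        _current_context = previous


def set_buffer(name, tensor):
    _current_context.buffers[name] = tensor


def get_buffer(name):
    return _current_context.buffers.get(name)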
nunchaku/caching/utils.py

@@ -3,7 +3,7 @@
import contextlib
import dataclasses
from collections import defaultdict
-from typing import DefaultDict, Dict
+from typing import DefaultDict, Dict, Optional

import torch
from torch import nn

@@ -81,18 +81,19 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
@torch.compiler.disable
def apply_prev_hidden_states_residual(
-    hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+    hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before"
    hidden_states = hidden_states_residual + hidden_states
-    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
-    assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
-    encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
-    hidden_states = hidden_states.contiguous()
-    encoder_hidden_states = encoder_hidden_states.contiguous()
+    hidden_states = hidden_states.contiguous()
+    if encoder_hidden_states is not None:
+        encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
+        assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
+        encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
+        encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states

@@ -108,8 +109,124 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=Fals
    )
    return can_use_cache


+class SanaCachedTransformerBlocks(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer=None,
+        residual_diff_threshold,
+        verbose: bool = False,
+    ):
+        super().__init__()
+        self.transformer = transformer
+        self.transformer_blocks = transformer.transformer_blocks
+        self.residual_diff_threshold = residual_diff_threshold
+        self.verbose = verbose
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask=None,
+        timestep=None,
+        post_patch_height=None,
+        post_patch_width=None,
+    ):
+        batch_size = hidden_states.shape[0]
+        if self.residual_diff_threshold <= 0.0 or batch_size > 2:
+            if batch_size > 2:
+                print("Batch size > 2 (for SANA CFG) currently not supported")
+            first_transformer_block = self.transformer_blocks[0]
+            hidden_states = first_transformer_block(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                timestep=timestep,
+                height=post_patch_height,
+                width=post_patch_width,
+                skip_first_layer=False,
+            )
+            return hidden_states
+
+        original_hidden_states = hidden_states
+        first_transformer_block = self.transformer_blocks[0]
+        hidden_states = first_transformer_block.forward_layer_at(
+            0,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            timestep=timestep,
+            height=post_patch_height,
+            width=post_patch_width,
+        )
+        first_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        can_use_cache = get_can_use_cache(
+            first_hidden_states_residual,
+            threshold=self.residual_diff_threshold,
+            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            del first_hidden_states_residual
+            if self.verbose:
+                print("Cache hit!!!")
+            hidden_states, _ = apply_prev_hidden_states_residual(hidden_states, None)
+        else:
+            if self.verbose:
+                print("Cache miss!!!")
+            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
+            del first_hidden_states_residual
+            hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                timestep=timestep,
+                post_patch_height=post_patch_height,
+                post_patch_width=post_patch_width,
+            )
+            set_buffer("hidden_states_residual", hidden_states_residual)
+        torch._dynamo.graph_break()
+
+        return hidden_states
+
+    def call_remaining_transformer_blocks(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask=None,
+        timestep=None,
+        post_patch_height=None,
+        post_patch_width=None,
+    ):
+        first_transformer_block = self.transformer_blocks[0]
+        original_hidden_states = hidden_states
+        hidden_states = first_transformer_block(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            timestep=timestep,
+            height=post_patch_height,
+            width=post_patch_width,
+            skip_first_layer=True,
+        )
+        hidden_states_residual = hidden_states - original_hidden_states
+        return hidden_states, hidden_states_residual
+
+
-class CachedTransformerBlocks(nn.Module):
+class FluxCachedTransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
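
The cache decision comes from get_can_use_cache, which compares the current first-block residual with the one saved via set_buffer("first_hidden_states_residual", ...) on the previous cache miss. Only the signatures of get_can_use_cache and are_two_tensors_similar are visible in the hunk headers above, so the relative-difference check below is a sketch of the usual first-block-cache criterion under that assumption, not the file's verbatim implementation:

import torch

from nunchaku.caching.utils import get_buffer


def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
    # Sketch: mean absolute difference normalized by the magnitude of t1.
    # A real implementation might also all-reduce the statistics when parallelized.
    diff = (t1 - t2).abs().mean() / t1.abs().mean()
    return diff.item() < threshold


def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    prev = get_buffer("first_hidden_states_residual")  # stored on the last cache miss
    if prev is None:
        return False  # the first step of a generation can never hit the cache
    return are_two_tensors_similar(
        prev, first_hidden_states_residual, threshold=threshold, parallelized=parallelized
    )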
nunchaku/csrc/sana.h

@@ -26,15 +26,16 @@ public:
    }

    torch::Tensor forward(
        torch::Tensor hidden_states,
        torch::Tensor encoder_hidden_states,
        torch::Tensor timestep,
        torch::Tensor cu_seqlens_img,
        torch::Tensor cu_seqlens_txt,
        int H,
        int W,
        bool pag,
-        bool cfg)
+        bool cfg,
+        bool skip_first_layer = false)
    {
        checkModel();

        CUDADeviceContext ctx(deviceId);

@@ -54,7 +55,8 @@ public:
            from_torch(cu_seqlens_img),
            from_torch(cu_seqlens_txt),
            H, W,
-            pag, cfg
+            pag, cfg,
+            skip_first_layer
        );

        torch::Tensor output = to_torch(result);

@@ -65,15 +67,15 @@ public:
    torch::Tensor forward_layer(
        int64_t idx,
        torch::Tensor hidden_states,
        torch::Tensor encoder_hidden_states,
        torch::Tensor timestep,
        torch::Tensor cu_seqlens_img,
        torch::Tensor cu_seqlens_txt,
        int H,
        int W,
        bool pag,
        bool cfg)
    {
        checkModel();

        CUDADeviceContext ctx(deviceId);
nunchaku/models/transformers/transformer_sana.py

@@ -30,6 +30,7 @@ class NunchakuSanaTransformerBlocks(nn.Module):
        timestep: Optional[torch.LongTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
+        skip_first_layer: Optional[bool] = False,
    ):
        batch_size = hidden_states.shape[0]

@@ -69,6 +70,61 @@ class NunchakuSanaTransformerBlocks(nn.Module):
                width,
                batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
                True,  # TODO: find a way to detect if we are doing CFG
+                skip_first_layer,
            )
            .to(original_dtype)
            .to(original_device)
        )
+
+    def forward_layer_at(
+        self,
+        idx: int,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ):
+        batch_size = hidden_states.shape[0]
+        img_tokens = hidden_states.shape[1]
+        txt_tokens = encoder_hidden_states.shape[1]
+
+        original_dtype = hidden_states.dtype
+        original_device = hidden_states.device
+
+        assert encoder_attention_mask is not None
+        assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
+        mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
+        nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
+        cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
+        cu_seqlens_img = torch.arange(
+            0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
+        )
+
+        if height is None and width is None:
+            height = width = int(img_tokens**0.5)
+        elif height is None:
+            height = img_tokens // width
+        elif width is None:
+            width = img_tokens // height
+        assert height * width == img_tokens
+
+        return (
+            self.m.forward_layer(
+                idx,
+                hidden_states.to(self.dtype).to(self.device),
+                nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
+                timestep.to(self.dtype).to(self.device),
+                cu_seqlens_img.to(self.device),
+                cu_seqlens_txt.to(self.device),
+                height,
+                width,
+                batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
+                True,  # TODO: find a way to detect if we are doing CFG
+            )
+            .to(original_dtype)
+            .to(original_device)
+        )
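
forward_layer_at mirrors the preprocessing already done in forward: it drops padded text tokens and hands the C++ forward_layer cumulative sequence lengths instead of attention masks. A small worked example of those two tensors, with shapes chosen purely for illustration:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 2, 16 image tokens each, text padded to 8 tokens.
batch_size, img_tokens, txt_tokens = 2, 16, 8
mask = torch.full((batch_size, txt_tokens), -10000.0)
mask[0, :5] = 0.0  # sample 0 keeps 5 text tokens
mask[1, :3] = 0.0  # sample 1 keeps 3 text tokens

# Same formulas as in forward_layer_at above:
cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
cu_seqlens_img = torch.arange(0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32)

print(cu_seqlens_txt)  # tensor([0, 5, 8], dtype=torch.int32)
print(cu_seqlens_img)  # tensor([ 0, 16, 32], dtype=torch.int32)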
src/SanaModel.cpp

+#include <iostream>
+
#include "SanaModel.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"

@@ -8,6 +10,7 @@
using spdlog::fmt_lib::format;

using namespace nunchaku;

SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_pad(ceilDiv(dim, 128) * 128),

@@ -28,7 +31,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_
Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
    constexpr int HEAD_DIM = 32;

    assert(x.ndims() == 3);
    const int batch_size = x.shape[0];
    const int num_tokens = x.shape[1];

@@ -45,7 +48,7 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
        for (int i = 0; i < batch_size; i++) {
            x_pad.slice(0, i, i + 1).slice(1, 0, num_tokens).copy_(x.slice(0, i, i + 1));
        }
        x = x_pad;
    }

@@ -55,14 +58,14 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
    Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());

    kernels::gemm_w4a4(
        qact.act,
        qkv_proj.qweight,
        {},
        {},
        qact.ascales,
        qkv_proj.wscales,
        {},
        qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
        vk, q,
        qact.is_unsigned, qkv_proj.lora_scales, false,
        qkv_proj.use_fp4,
        *qkv_proj.wtscale.data_ptr<float>(),

@@ -118,12 +121,12 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
    }

    this->forward(x_org, out_org);
    Tensor v_ptb = this->pag_to_v.value().forward(x_ptb);
    this->out_proj.forward(v_ptb, out_ptb);

    return out;
}

MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    num_heads(num_heads), head_dim(head_dim),

@@ -143,7 +146,7 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
    assert(cond.ndims() == 2);
    assert(cu_seqlens_img.ndims() == 1);
    assert(cu_seqlens_txt.ndims() == 1);

    const int batch_size = x.shape[0];
    const int num_tokens_img = x.shape[1];
    const int num_tokens_txt = cond.shape[0];

@@ -163,21 +166,21 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
        num_tokens_img, num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false,
        false,
        -1, -1,
        false
    ).front().view({batch_size, num_tokens_img, num_heads * head_dim});

    // Tensor attn_output = mha_fwd(q, k, v,
    //     0.0f,
    //     pow(q.shape[-1], (-0.5)),
    //     false, -1, -1, false
    // ).front().view({B, N, num_heads * head_dim});

    return out_proj.forward(attn_output);
}

SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    in_features(in_features), hidden_features(hidden_features),
    inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
    depth_conv(hidden_features * 2, true, dtype, device),

@@ -204,7 +207,7 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
    return point_conv.forward_quant(qact);
}

SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
    attn(hidden_size, false, pag, use_fp4, dtype, device),
    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),

@@ -240,7 +243,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
    kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false);
    debug("shifted_timestep", timestep);

    std::array<Tensor, 6> chunked;
    for (int i = 0; i < 6; i++) {
        chunked[i] = timestep.slice(1, i, i + 1);

@@ -299,7 +302,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
        nvtxRangePop();
    }
    nvtxRangePop();

    debug("hidden_states_out", hidden_states);

@@ -307,7 +310,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
    return hidden_states;
}

SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
    config(config)
{
    const int inner_dim = config.num_attention_heads * config.attention_head_dim;

@@ -324,8 +327,8 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
    }
}

-Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
-    for (int i = 0; i < config.num_layers; i++) {
+Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
+    for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
        auto &&block = transformer_blocks[i];
        hidden_states = block->forward(
            hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
src/SanaModel.h

@@ -89,7 +89,7 @@ struct SanaConfig {
class SanaModel : public Module {
public:
    SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
+    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);

public:
    const SanaConfig config;