fengzch-das / nunchaku / Commits / ad8097b9

Commit ad8097b9 (unverified signature), authored Apr 04, 2025 by Muyang Li, committed by GitHub on Apr 04, 2025.

Release v0.2.0

Ready to release v0.2.0

Parents: 804a6d30, 998192ca
The commit changes 142 files in total.
This page shows 20 changed files with 714 additions and 122 deletions (+714, −122).
examples/fp4-flux.1-schnell.py                       +0   −13
examples/int4-flux.1-schnell-offload.py              +0   −16
examples/int4-flux.1-schnell-qencoder-offload.py     +0   −20
examples/int4-flux.1-schnell-qencoder.py             +0   −17
examples/sana_1600m-cache.py                         +31  −0
examples/sana_1600m.py                               +1   −1
examples/sana_1600m_pag.py                           +1   −1
nunchaku/__version__.py                              +1   −1
nunchaku/caching/__init__.py                         +0   −0
nunchaku/caching/diffusers_adapters/__init__.py      +14  −0
nunchaku/caching/diffusers_adapters/flux.py          +56  −0
nunchaku/caching/diffusers_adapters/sana.py          +50  −0
nunchaku/caching/utils.py                            +335 −0
nunchaku/csrc/flux.h                                 +61  −21
nunchaku/csrc/gemm.h                                 +7   −7
nunchaku/csrc/module.h                               +23  −0
nunchaku/csrc/ops.h                                  +56  −2
nunchaku/csrc/pybind.cpp                             +40  −5
nunchaku/csrc/sana.h                                 +25  −18
nunchaku/csrc/utils.h                                +13  −0
examples/fp4-flux.1-schnell.py (deleted, 100644 → 0)

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
examples/int4-flux.1-schnell-offload.py (deleted, 100644 → 0)

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
)  # set offload to False if you want to disable offloading
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
examples/int4-flux.1-schnell-qencoder-offload.py (deleted, 100644 → 0)

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
)  # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
examples/int4-flux.1-schnell-qencoder.py (deleted, 100644 → 0)

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
examples/sana_1600m-cache.py (new file, 0 → 100644)

import torch
from diffusers import SanaPipeline

from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    variant="bf16",
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

apply_cache_on_pipe(pipe, residual_diff_threshold=0.25)

# WarmUp
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
examples/int4-sana_1600m.py → examples/sana_1600m.py (renamed)

@@ -23,4 +23,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m.png")
+image.save("sana_1600m-int4.png")
examples/int4-sana_1600m_pag.py → examples/sana_1600m_pag.py (renamed)

@@ -24,4 +24,4 @@ image = pipe(
     pag_scale=2.0,
     num_inference_steps=20,
 ).images[0]
-image.save("sana_1600m_pag.png")
+image.save("sana_1600m_pag-int4.png")
nunchaku/__version__.py

-__version__ = "0.1.4"
+__version__ = "0.2.0"
comfyui/nodes/__init__.py → nunchaku/caching/__init__.py (file moved)
nunchaku/caching/diffusers_adapters/__init__.py (new file, 0 → 100644)

from diffusers import DiffusionPipeline


def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)

    pipe_cls_name = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
    elif pipe_cls_name.startswith("Sana"):
        from .sana import apply_cache_on_pipe as apply_cache_on_pipe_fn
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
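
The dispatcher above picks the Flux or Sana adapter from the pipeline's class-name prefix, so it can be driven directly from user code. A minimal sketch of enabling the first-block cache on a FLUX pipeline, mirroring the SANA example earlier in this commit; the FLUX.1-dev checkpoint names and the 0.12 threshold are illustrative assumptions, not taken from this diff:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

# Illustrative checkpoints; any Flux* pipeline dispatches to the .flux adapter.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# 0.12 matches the Flux adapter's default; larger values cache more aggressively.
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)

image = pipeline("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("flux.1-dev-cached.png")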
nunchaku/caching/diffusers_adapters/flux.py (new file, 0 → 100644)

import functools
import unittest

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn

from ...caching import utils


def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = nn.ModuleList(
        [
            utils.FluxCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
                return_hidden_states_first=False,
            )
        ]
    )
    dummy_single_transformer_blocks = nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with (
            unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks),
            unittest.mock.patch.object(self, "single_transformer_blocks", dummy_single_transformer_blocks),
        ):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
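
A note on the shallow_patch flag above: it only installs the per-call cache context on the pipeline class and leaves the transformer untouched, which is useful when the transformer is patched separately. A hedged sketch, assuming pipe is an already-constructed Flux pipeline:

from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe, apply_cache_on_transformer

# Patch the transformer yourself with an explicit threshold...
apply_cache_on_transformer(pipe.transformer, residual_diff_threshold=0.1)
# ...then only wrap __call__ so each generation gets a fresh cache context.
apply_cache_on_pipe(pipe, shallow_patch=True)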
nunchaku/caching/diffusers_adapters/sana.py (new file, 0 → 100644)

import functools
import unittest

import torch
from diffusers import DiffusionPipeline, SanaTransformer2DModel

from ...caching import utils


def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.SanaCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
nunchaku/caching/utils.py (new file, 0 → 100644)

# This caching functionality is largely brought from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional

import torch
from torch import nn


@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_name(self):
        self.incremental_name_counters.clear()

    # @torch.compiler.disable # This is a torchscript feature
    def get_buffer(self, name=str):
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()


@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.set_buffer(name, buffer)


_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context


@torch.compiler.disable
def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold


@torch.compiler.disable
def apply_prev_hidden_states_residual(
    hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before"
    hidden_states = hidden_states_residual + hidden_states
    hidden_states = hidden_states.contiguous()

    if encoder_hidden_states is not None:
        encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
        assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
        encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
        encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states


@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
        parallelized=parallelized,
    )
    return can_use_cache


class SanaCachedTransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
        transformer=None,
        residual_diff_threshold,
        verbose: bool = False,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer.transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.verbose = verbose

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask=None,
        timestep=None,
        post_patch_height=None,
        post_patch_width=None,
    ):
        batch_size = hidden_states.shape[0]
        if self.residual_diff_threshold <= 0.0 or batch_size > 2:
            if batch_size > 2:
                print("Batch size > 2 (for SANA CFG) currently not supported")

            first_transformer_block = self.transformer_blocks[0]
            hidden_states = first_transformer_block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                height=post_patch_height,
                width=post_patch_width,
                skip_first_layer=False,
            )
            return hidden_states

        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states = first_transformer_block.forward_layer_at(
            0,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            height=post_patch_height,
            width=post_patch_width,
        )
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            if self.verbose:
                print("Cache hit!!!")
            hidden_states, _ = apply_prev_hidden_states_residual(hidden_states, None)
        else:
            if self.verbose:
                print("Cache miss!!!")
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                post_patch_height=post_patch_height,
                post_patch_width=post_patch_width,
            )
            set_buffer("hidden_states_residual", hidden_states_residual)
        torch._dynamo.graph_break()

        return hidden_states

    def call_remaining_transformer_blocks(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask=None,
        timestep=None,
        post_patch_height=None,
        post_patch_width=None,
    ):
        first_transformer_block = self.transformer_blocks[0]
        original_hidden_states = hidden_states
        hidden_states = first_transformer_block(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            height=post_patch_height,
            width=post_patch_width,
            skip_first_layer=True,
        )
        hidden_states_residual = hidden_states - original_hidden_states
        return hidden_states, hidden_states_residual


class FluxCachedTransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
        transformer=None,
        residual_diff_threshold,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
        verbose: bool = False,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer.transformer_blocks
        self.single_transformer_blocks = transformer.single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.verbose = verbose

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        batch_size = hidden_states.shape[0]
        if self.residual_diff_threshold <= 0.0 or batch_size > 1:
            if batch_size > 1:
                print("Batch size > 1 currently not supported")

            first_transformer_block = self.transformer_blocks[0]
            encoder_hidden_states, hidden_states = first_transformer_block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, *args, **kwargs
            )
            return (
                hidden_states
                if self.return_hidden_states_only
                else (
                    (hidden_states, encoder_hidden_states)
                    if self.return_hidden_states_first
                    else (encoder_hidden_states, hidden_states)
                )
            )

        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        encoder_hidden_states, hidden_states = first_transformer_block.forward_layer_at(
            0, hidden_states, encoder_hidden_states, *args, **kwargs
        )
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            if self.verbose:
                print("Cache hit!!!")
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            if self.verbose:
                print("Cache miss!!!")
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        return (
            hidden_states
            if self.return_hidden_states_only
            else (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )
        )

    def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        first_transformer_block = self.transformer_blocks[0]
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        encoder_hidden_states, hidden_states = first_transformer_block.forward(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            skip_first_layer=True,
            *args,
            **kwargs,
        )
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
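
The buffer helpers above only work inside an active cache context, which the pipeline adapters install around each __call__. A minimal sketch of the API in isolation; the tensor shape is arbitrary and purely illustrative:

import torch
from nunchaku.caching import utils

with utils.cache_context(utils.create_cache_context()):
    # Buffers are keyed by name and live only for the duration of the context.
    utils.set_buffer("first_hidden_states_residual", torch.zeros(1, 16, 64))
    cached = utils.get_buffer("first_hidden_states_residual")
    assert cached.shape == (1, 16, 64)

# Outside the context manager, get_buffer/set_buffer assert, because no cache context is active.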
nunchaku/csrc/flux.h

@@ -10,22 +10,37 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
     void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedFluxModel");
+        spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
+        if (!bf16) {
+            spdlog::info("Use FP16 model");
+        }
         if (offload) {
             spdlog::info("Layer offloading enabled");
         }
+        ModuleWrapper::init(deviceId);
+        CUDADeviceContext ctx(this->deviceId);
         net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
+    bool isBF16() {
+        checkModel();
+        return net->dtype == Tensor::BF16;
+    }
     torch::Tensor forward(
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor temb,
         torch::Tensor rotary_emb_img,
         torch::Tensor rotary_emb_context,
-        torch::Tensor rotary_emb_single)
+        torch::Tensor rotary_emb_single,
+        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
+        bool skip_first_layer = false)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedFluxModel forward");
@@ -42,7 +57,10 @@ public:
             from_torch(temb),
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
-            from_torch(rotary_emb_single)
+            from_torch(rotary_emb_single),
+            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
+            skip_first_layer
         );
         torch::Tensor output = to_torch(result);
@@ -53,12 +71,16 @@ public:
     std::tuple<torch::Tensor, torch::Tensor> forward_layer(
         int64_t idx,
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor temb,
         torch::Tensor rotary_emb_img,
-        torch::Tensor rotary_emb_context)
+        torch::Tensor rotary_emb_context,
+        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
     {
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -67,17 +89,19 @@ public:
         rotary_emb_img = rotary_emb_img.contiguous();
         rotary_emb_context = rotary_emb_context.contiguous();
-        auto &&[result_img, result_txt] = net->transformer_blocks.at(idx)->forward(
+        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
+            idx,
             from_torch(hidden_states),
             from_torch(encoder_hidden_states),
             from_torch(temb),
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
-            0.0f
+            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
         );
-        hidden_states = to_torch(result_img);
-        encoder_hidden_states = to_torch(result_txt);
+        hidden_states = to_torch(hidden_states_);
+        encoder_hidden_states = to_torch(encoder_hidden_states_);
         Tensor::synchronizeDevice();
         return {hidden_states, encoder_hidden_states};
@@ -85,10 +109,12 @@ public:
     torch::Tensor forward_single_layer(
         int64_t idx,
         torch::Tensor hidden_states,
         torch::Tensor temb,
         torch::Tensor rotary_emb_single)
     {
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -115,6 +141,8 @@ public:
             throw std::invalid_argument("skipRanks must be multiples of 16");
         }
+        CUDADeviceContext ctx(deviceId);
         spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
         net->traverse([&](Module *module) {
@@ -131,8 +159,20 @@ public:
         });
     }
-    void forceFP16Attention(bool enable) {
-        Attention::setForceFP16(net.get(), enable);
+    void setAttentionImpl(std::string name) {
+        if (name.empty() || name == "default") {
+            name = "flashattn2";
+        }
+        spdlog::info("Set attention implementation to {}", name);
+        if (name == "flashattn2") {
+            net->setAttentionImpl(AttentionImpl::FlashAttention2);
+        } else if (name == "nunchaku-fp16") {
+            net->setAttentionImpl(AttentionImpl::NunchakuFP16);
+        } else {
+            throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
+        }
     }
 };
\ No newline at end of file
nunchaku/csrc/gemm.h

@@ -10,7 +10,7 @@ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
 public:
     void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedGEMM");
         size_t val = 0;
         checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
@@ -27,7 +27,7 @@ public:
         x = x.contiguous();
         Tensor result = net->forward(from_torch(x));
         torch::Tensor output = to_torch(result);
         Tensor::synchronizeDevice();
@@ -48,7 +48,7 @@ public:
         const int M = x.shape[0];
         const int K = x.shape[1] * 2;
         assert(x.dtype() == Tensor::INT8);
         // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
@@ -67,7 +67,7 @@ public:
         const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;
         for (int i = 0; i < 16; i++) {
-            assert(offset + i < x.numel() / 4);
+            assert(static_cast<size_t>(offset + i) < x.numel() / 4);
             uint32_t val = x.data_ptr<uint32_t>()[offset + i];
             ss << "{";
             for (int j = 0; j < 8; j++) {
@@ -83,7 +83,7 @@ public:
                 }
             }
         }
         ss << std::endl;
         return ss.str();
     }
@@ -99,7 +99,7 @@ public:
             from_torch(x),
             fuse_glu
         );
         Tensor act = qout.act.copy(Device::cpu());
         Tensor ascales = qout.ascales.copy(Device::cpu());
         Tensor lora_act = qout.lora_act.copy(Device::cpu());
@@ -110,4 +110,4 @@ public:
         spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
     }
 };
\ No newline at end of file
nunchaku/csrc/module.h

@@ -9,7 +9,12 @@
 template<typename M>
 class ModuleWrapper {
 public:
+    void init(int deviceId) {
+        this->deviceId = deviceId;
+    }
     void reset() {
+        CUDADeviceContext ctx(this->deviceId);
         debugContext.reset();
         net.reset();
         Tensor::synchronizeDevice();
@@ -20,6 +25,7 @@ public:
     void load(std::string path, bool partial = false) {
         checkModel();
+        CUDADeviceContext ctx(this->deviceId);
         spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
@@ -30,6 +36,19 @@ public:
         spdlog::info("Done.");
     }
+    void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
+        checkModel();
+        CUDADeviceContext ctx(this->deviceId);
+        spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
+        std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
+        net->loadParams(*provider, partial);
+        Tensor::synchronizeDevice();
+        spdlog::info("Done.");
+    }
     void startDebug() {
         debugContext = std::make_unique<DebugContext>();
     }
@@ -38,6 +57,8 @@ public:
     }
     auto getDebugResults() {
+        CUDADeviceContext ctx(this->deviceId);
         std::map<std::string, torch::Tensor> result;
         if (debugContext) {
@@ -59,4 +80,6 @@ protected:
 protected:
     std::unique_ptr<M> net;
     std::unique_ptr<DebugContext> debugContext;
+    int deviceId = -1;
 };
\ No newline at end of file
nunchaku/csrc/ops.h

@@ -32,7 +32,11 @@ namespace nunchaku::ops {
                    bool fuse_silu,
                    bool fp4,
                    float alpha,
-                   std::optional<torch::Tensor> wcscales
+                   std::optional<torch::Tensor> wcscales,
+                   std::optional<torch::Tensor> out_q,  // packed attention [B, H, M, D]
+                   std::optional<torch::Tensor> out_k,  // packed attention [B, H, M, D]
+                   std::optional<torch::Tensor> out_v,  // packed attention [B, H, M, D]
+                   int attn_tokens
     ) {
         spdlog::trace("running gemm_w4a4: ");
@@ -70,11 +74,31 @@ namespace nunchaku::ops {
             fuse_silu,
             fp4,
             alpha,
-            getTensor(wcscales)
+            getTensor(wcscales),
+            getTensor(out_q),
+            getTensor(out_k),
+            getTensor(out_v),
+            attn_tokens
         );
         // Tensor::synchronizeDevice();
     }
+    void attention_fp16(torch::Tensor q,  // packed [Batch, Head, TokensQ, HEAD_DIM]
+                        torch::Tensor k,  // packed [Batch, Head, TokensKV, HEAD_DIM]
+                        torch::Tensor v,  // packed [Batch, Head, TokensKV, HEAD_DIM]
+                        torch::Tensor o,  // linear [Batch, TokensQ, Head * HEAD_DIM]
+                        float scale)
+    {
+        nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
+    }
     torch::Tensor gemv_awq(
         torch::Tensor _in_feats,
         torch::Tensor _kernel,
@@ -122,6 +146,36 @@ namespace nunchaku::ops {
         return output;
     }
+    void test_rmsnorm_rope(torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
+        nunchaku::kernels::test_rmsnorm_rope(from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
+    }
+    void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
+        nunchaku::kernels::test_pack_qkv(from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
+    }
 };
\ No newline at end of file
nunchaku/csrc/pybind.cpp

@@ -18,18 +18,42 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("deviceId")
         )
         .def("reset", &QuantizedFluxModel::reset)
         .def("load", &QuantizedFluxModel::load,
             py::arg("path"),
             py::arg("partial") = false
         )
-        .def("forward", &QuantizedFluxModel::forward)
-        .def("forward_layer", &QuantizedFluxModel::forward_layer)
+        .def("loadDict", &QuantizedFluxModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
+        .def("forward", &QuantizedFluxModel::forward,
+            py::arg("hidden_states"),
+            py::arg("encoder_hidden_states"),
+            py::arg("temb"),
+            py::arg("rotary_emb_img"),
+            py::arg("rotary_emb_context"),
+            py::arg("rotary_emb_single"),
+            py::arg("controlnet_block_samples") = py::none(),
+            py::arg("controlnet_single_block_samples") = py::none(),
+            py::arg("skip_first_layer") = false
+        )
+        .def("forward_layer", &QuantizedFluxModel::forward_layer,
+            py::arg("idx"),
+            py::arg("hidden_states"),
+            py::arg("encoder_hidden_states"),
+            py::arg("temb"),
+            py::arg("rotary_emb_img"),
+            py::arg("rotary_emb_context"),
+            py::arg("controlnet_block_samples") = py::none(),
+            py::arg("controlnet_single_block_samples") = py::none()
+        )
         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
         .def("startDebug", &QuantizedFluxModel::startDebug)
         .def("stopDebug", &QuantizedFluxModel::stopDebug)
         .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
         .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
-        .def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
+        .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
+        .def("isBF16", &QuantizedFluxModel::isBF16)
     ;
     py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
         .def(py::init<>())
@@ -41,10 +65,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("deviceId")
         )
         .def("reset", &QuantizedSanaModel::reset)
         .def("load", &QuantizedSanaModel::load,
             py::arg("path"),
             py::arg("partial") = false
         )
+        .def("loadDict", &QuantizedSanaModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
         .def("forward", &QuantizedSanaModel::forward)
         .def("forward_layer", &QuantizedSanaModel::forward_layer)
         .def("startDebug", &QuantizedSanaModel::startDebug)
@@ -74,15 +102,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     ;
     m.def_submodule("ops")
         .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+        .def("attention_fp16", nunchaku::ops::attention_fp16)
         .def("gemm_awq", nunchaku::ops::gemm_awq)
         .def("gemv_awq", nunchaku::ops::gemv_awq)
+        .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
+        .def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
     ;
     m.def_submodule("utils")
         .def("set_log_level", [](const std::string &level) {
            spdlog::set_level(spdlog::level::from_str(level));
        })
+        .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
        .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
        .def("trim_memory", nunchaku::utils::trim_memory)
+        .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode)
     ;
}
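
For orientation, the new bindings (loadDict, the keyword arguments on forward/forward_layer, setAttentionImpl, isBF16) become callable from Python once the extension is built. A hedged sketch, assuming the compiled module is importable as nunchaku._C (the actual import path depends on the build) and that the weight path is purely illustrative:

from nunchaku import _C  # assumed import path for the compiled extension

model = _C.QuantizedFluxModel()
model.init(False, False, True, 0)                  # use_fp4=False, offload=False, bf16=True, deviceId=0
model.load("path/to/svdq-int4-flux.safetensors")   # illustrative path; loadDict accepts an in-memory state dict instead
model.setAttentionImpl("nunchaku-fp16")            # or "flashattn2" / "default"
print(model.isBF16())                              # True when the model was initialized with bf16=True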
nunchaku/csrc/sana.h

@@ -9,7 +9,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
     void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedSanaModel");
+        spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),
             .num_attention_heads = config["num_attention_heads"].cast<int>(),
@@ -19,21 +19,26 @@ public:
             .pag_layers = pag_layers,
             .use_fp4 = use_fp4,
         };
+        ModuleWrapper::init(deviceId);
+        CUDADeviceContext ctx(this->deviceId);
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
     torch::Tensor forward(
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor timestep,
         torch::Tensor cu_seqlens_img,
         torch::Tensor cu_seqlens_txt,
         int H,
         int W,
         bool pag,
-        bool cfg)
+        bool cfg,
+        bool skip_first_layer = false)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward");
@@ -50,7 +55,8 @@ public:
             from_torch(cu_seqlens_img),
             from_torch(cu_seqlens_txt),
             H, W,
-            pag, cfg
+            pag, cfg, skip_first_layer
         );
         torch::Tensor output = to_torch(result);
@@ -61,17 +67,18 @@ public:
     torch::Tensor forward_layer(
         int64_t idx,
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor timestep,
         torch::Tensor cu_seqlens_img,
         torch::Tensor cu_seqlens_txt,
         int H,
         int W,
         bool pag,
         bool cfg)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
nunchaku/csrc/utils.h

@@ -2,9 +2,17 @@
 #include "common.h"
 #include "Tensor.h"
+#include "kernels/zgemm/zgemm.h"
 namespace nunchaku::utils {
+    void set_cuda_stack_limit(int64_t newval) {
+        size_t val = 0;
+        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
+        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+        spdlog::debug("Stack={}", val);
+    }
     void disable_memory_auto_release() {
         int device;
         checkCUDA(cudaGetDevice(&device));
@@ -23,4 +31,9 @@ namespace nunchaku::utils {
         checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
     }
+    void set_faster_i2f_mode(std::string mode) {
+        spdlog::info("Set fasteri2f mode to {}", mode);
+        kernels::set_faster_i2f_mode(mode);
+    }
 };
\ No newline at end of file