fengzch-das / nunchaku / Commits / 235238bd

Commit 235238bd, authored Apr 02, 2025 by Hyunsung Lee; committed Apr 04, 2025 by Zhekai Zhang

Add controlnet

parent 63913f29
Showing 9 changed files with 466 additions and 65 deletions (+466, -65)
- examples/controlnet-flux-cache.py (+60, -0)
- examples/controlnet-flux.py (+59, -0)
- examples/ref-controlnet.py (+48, -0)
- nunchaku/csrc/flux.h (+13, -5)
- nunchaku/csrc/gemm.h (+5, -5)
- nunchaku/csrc/pybind.cpp (+25, -6)
- nunchaku/models/transformers/transformer_flux.py (+118, -2)
- src/FluxModel.cpp (+117, -43)
- src/FluxModel.h (+21, -4)
examples/controlnet-flux-cache.py (new file, mode 100644)

```python
import random

import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'

controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend loading via FluxMultiControlNetModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, transformer=transformer, controlnet=controlnet, torch_dtype=torch.bfloat16
)
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
pipe.to("cuda")

prompt = 'A anime style girl with messy beach waves.'

control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
control_mode_depth = 2

control_image_canny = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)
control_mode_canny = 0

width, height = control_image_depth.size

image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
```
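For readers skimming the diff: `apply_cache_on_pipe` enables first-block caching, where the residual produced by the first transformer block is compared across denoising steps and, when the relative change stays under `residual_diff_threshold`, cached outputs of the later blocks are reused. The sketch below illustrates that decision rule only; it assumes nunchaku's adapter follows the usual first-block-cache scheme and is not the adapter's actual code.

```python
import torch

def should_reuse_cache(prev_residual: torch.Tensor,
                       curr_residual: torch.Tensor,
                       threshold: float = 0.12) -> bool:
    # Relative mean-absolute change of the first block's residual between
    # two consecutive denoising steps; a small change suggests the later
    # blocks would produce nearly the same outputs, so reuse the cache.
    rel_diff = (curr_residual - prev_residual).abs().mean() / prev_residual.abs().mean()
    return rel_diff.item() < threshold
```

Larger thresholds skip more blocks and run faster at some cost in fidelity; 0.12 is simply the value this example ships with.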
examples/controlnet-flux.py (new file, mode 100644)

```python
import random

import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe  # imported but unused in this variant

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'

controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend loading via FluxMultiControlNetModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, transformer=transformer, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

prompt = 'A anime style girl with messy beach waves.'

control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
control_mode_depth = 2

control_image_canny = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)
control_mode_canny = 0

width, height = control_image_depth.size

image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
```
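Both examples download ready-made depth and canny maps from the ControlNet-Union-Pro repository. For custom inputs, a canny control image can be produced locally; a minimal sketch using OpenCV (assumes `opencv-python` and `Pillow` are installed, and `my_photo.png` is a placeholder path):

```python
import cv2
import numpy as np
from PIL import Image

source = np.array(Image.open("my_photo.png").convert("RGB"))
gray = cv2.cvtColor(source, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 100, 200)  # low/high hysteresis thresholds
control_image_canny = Image.fromarray(np.stack([edges] * 3, axis=-1))  # 3-channel map, as the pipeline expects
```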
examples/ref-controlnet.py (new file, mode 100644)

```python
import random

import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel  # imported but unused; this reference script keeps the stock transformer
from diffusers.utils import load_image
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'

controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend loading via FluxMultiControlNetModel

pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = 'A anime style girl with messy beach waves.'

control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
control_mode_depth = 2

control_image_canny = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)
control_mode_canny = 0

width, height = control_image_depth.size

image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("reference-controlnet-flux.1-dev.png")
```
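This reference script intentionally keeps the BF16 transformer so its output and resource usage can be compared against the SVDQuant INT4 examples above. A small sketch for recording peak GPU memory in either script (standard PyTorch calls, appended after `image.save(...)`):

```python
import torch

# torch.cuda.reset_peak_memory_stats() before pipe(...) gives per-run numbers.
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"peak CUDA memory allocated: {peak_gib:.2f} GiB")
```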
nunchaku/csrc/flux.h

```diff
@@ -35,6 +35,8 @@ public:
         torch::Tensor rotary_emb_img,
         torch::Tensor rotary_emb_context,
         torch::Tensor rotary_emb_single,
+        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
         bool skip_first_layer = false)
     {
         checkModel();
@@ -56,6 +58,8 @@ public:
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
             from_torch(rotary_emb_single),
+            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
             skip_first_layer
         );
@@ -71,7 +75,9 @@ public:
         torch::Tensor encoder_hidden_states,
         torch::Tensor temb,
         torch::Tensor rotary_emb_img,
-        torch::Tensor rotary_emb_context)
+        torch::Tensor rotary_emb_context,
+        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
     {
         CUDADeviceContext ctx(deviceId);
@@ -83,17 +89,19 @@ public:
         rotary_emb_img = rotary_emb_img.contiguous();
         rotary_emb_context = rotary_emb_context.contiguous();
 
-        auto &&[result_img, result_txt] = net->transformer_blocks.at(idx)->forward(
+        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
+            idx,
             from_torch(hidden_states),
             from_torch(encoder_hidden_states),
             from_torch(temb),
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
-            0.0f
+            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
         );
 
-        hidden_states = to_torch(result_img);
-        encoder_hidden_states = to_torch(result_txt);
+        hidden_states = to_torch(hidden_states_);
+        encoder_hidden_states = to_torch(encoder_hidden_states_);
         Tensor::synchronizeDevice();
 
         return {hidden_states, encoder_hidden_states};
```
nunchaku/csrc/gemm.h (+5, -5; hunks not expanded in this view)
nunchaku/csrc/pybind.cpp

```diff
@@ -26,8 +26,27 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("dict"),
             py::arg("partial") = false
         )
-        .def("forward", &QuantizedFluxModel::forward)
-        .def("forward_layer", &QuantizedFluxModel::forward_layer)
+        .def("forward", &QuantizedFluxModel::forward,
+            py::arg("hidden_states"),
+            py::arg("encoder_hidden_states"),
+            py::arg("temb"),
+            py::arg("rotary_emb_img"),
+            py::arg("rotary_emb_context"),
+            py::arg("rotary_emb_single"),
+            py::arg("controlnet_block_samples") = py::none(),
+            py::arg("controlnet_single_block_samples") = py::none(),
+            py::arg("skip_first_layer") = false
+        )
+        .def("forward_layer", &QuantizedFluxModel::forward_layer,
+            py::arg("idx"),
+            py::arg("hidden_states"),
+            py::arg("encoder_hidden_states"),
+            py::arg("temb"),
+            py::arg("rotary_emb_img"),
+            py::arg("rotary_emb_context"),
+            py::arg("controlnet_block_samples") = py::none(),
+            py::arg("controlnet_single_block_samples") = py::none()
+        )
         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
         .def("startDebug", &QuantizedFluxModel::startDebug)
         .def("stopDebug", &QuantizedFluxModel::stopDebug)
```
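With the arguments now named via `py::arg`, the extension can be called with keywords from Python, and the two ControlNet tensors can simply be omitted. A hedged call skeleton mirroring the `self.m.forward_layer(...)` call in `transformer_flux.py` below; the tensor variables are placeholders, so this is not runnable without the compiled extension and correctly shaped packed inputs:

```python
# m is the compiled QuantizedFluxModel instance held by the Python wrapper.
hidden_states, encoder_hidden_states = m.forward_layer(
    idx=0,
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
    rotary_emb_img=rotary_emb_img,
    rotary_emb_context=rotary_emb_txt,
    controlnet_block_samples=controlnet_block_samples,  # may be omitted; defaults to None
    controlnet_single_block_samples=None,               # defaults to None
)
```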
nunchaku/models/transformers/transformer_flux.py

```diff
@@ -1,2 +1,4 @@
+from typing import Any, Dict, Optional, Union
+
 import logging
 import os
@@ -5,6 +7,7 @@ import diffusers
 import torch
 from diffusers import FluxTransformer2DModel
 from diffusers.configuration_utils import register_to_config
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from huggingface_hub import utils
 from packaging.version import Version
 from safetensors.torch import load_file, save_file
@@ -62,6 +65,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         encoder_hidden_states: torch.Tensor,
         image_rotary_emb: torch.Tensor,
         joint_attention_kwargs=None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
         skip_first_layer=False,
     ):
         batch_size = hidden_states.shape[0]
@@ -76,6 +81,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         temb = temb.to(self.dtype).to(self.device)
         image_rotary_emb = image_rotary_emb.to(self.device)
+        if controlnet_block_samples is not None:
+            controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
+        if controlnet_single_block_samples is not None:
+            controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
+
         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
         assert image_rotary_emb.shape[1] == 1
@@ -89,7 +99,6 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
         rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
         rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
-
         hidden_states = self.m.forward(
             hidden_states,
             encoder_hidden_states,
@@ -97,6 +106,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
             rotary_emb_img,
             rotary_emb_txt,
             rotary_emb_single,
+            controlnet_block_samples,
+            controlnet_single_block_samples,
             skip_first_layer,
         )
@@ -115,6 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: torch.Tensor,
         joint_attention_kwargs=None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
     ):
         batch_size = hidden_states.shape[0]
         txt_tokens = encoder_hidden_states.shape[1]
@@ -128,6 +141,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         temb = temb.to(self.dtype).to(self.device)
         image_rotary_emb = image_rotary_emb.to(self.device)
+        if controlnet_block_samples is not None:
+            controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
+        if controlnet_single_block_samples is not None:
+            controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
+
         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
         assert image_rotary_emb.shape[1] == 1
@@ -141,7 +159,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
         hidden_states, encoder_hidden_states = self.m.forward_layer(
-            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt
+            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt,
+            controlnet_block_samples, controlnet_single_block_samples
         )
         hidden_states = hidden_states.to(original_dtype).to(original_device)
@@ -473,3 +492,100 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
         state_dict.update(updated_vectors)
         self.transformer_blocks[0].m.loadDict(state_dict, True)
         self.reset_x_embedder()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
+        return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
+    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        """
+        Copied from diffusers.models.flux.transformer_flux.py
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
+                A list of tensors that if specified are added to the residuals of transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        hidden_states = self.x_embedder(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+        else:
+            guidance = None
+
+        temb = (
+            self.time_text_embed(timestep, pooled_projections)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, pooled_projections)
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+        image_rotary_emb = self.pos_embed(ids)
+
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+        nunchaku_block = self.transformer_blocks[0]
+        encoder_hidden_states, hidden_states = nunchaku_block(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            temb=temb,
+            image_rotary_emb=image_rotary_emb,
+            joint_attention_kwargs=joint_attention_kwargs,
+            controlnet_block_samples=controlnet_block_samples,
+            controlnet_single_block_samples=controlnet_single_block_samples,
+        )
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+        hidden_states = self.norm_out(hidden_states, temb)
+        output = self.proj_out(hidden_states)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
```
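The new `forward` signature matches what diffusers' `FluxControlNetPipeline` passes to its transformer: within each denoising step, the pipeline first runs the ControlNet to obtain per-block residual lists, then hands them over as `controlnet_block_samples` / `controlnet_single_block_samples`. A simplified sketch of that hand-off, paraphrasing the diffusers pipeline loop rather than quoting it (guidance handling omitted for brevity):

```python
# Inside one denoising step of a FluxControlNetPipeline-style loop (simplified):
controlnet_block_samples, controlnet_single_block_samples = controlnet(
    hidden_states=latents,
    controlnet_cond=control_image,
    controlnet_mode=control_mode,
    conditioning_scale=conditioning_scale,
    timestep=timestep / 1000,
    pooled_projections=pooled_prompt_embeds,
    encoder_hidden_states=prompt_embeds,
    txt_ids=text_ids,
    img_ids=latent_image_ids,
    return_dict=False,
)
noise_pred = transformer(
    hidden_states=latents,
    timestep=timestep / 1000,
    pooled_projections=pooled_prompt_embeds,
    encoder_hidden_states=prompt_embeds,
    txt_ids=text_ids,
    img_ids=latent_image_ids,
    controlnet_block_samples=controlnet_block_samples,
    controlnet_single_block_samples=controlnet_single_block_samples,
    return_dict=False,
)[0]
```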
src/FluxModel.cpp

```diff
@@ -718,7 +718,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
     }
 }
 
-Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer) {
+Tensor FluxModel::forward(Tensor hidden_states,
+                          Tensor encoder_hidden_states,
+                          Tensor temb,
+                          Tensor rotary_emb_img,
+                          Tensor rotary_emb_context,
+                          Tensor rotary_emb_single,
+                          Tensor controlnet_block_samples,
+                          Tensor controlnet_single_block_samples,
+                          bool skip_first_layer) {
     const int batch_size = hidden_states.shape[0];
     const Tensor::ScalarType dtype = hidden_states.dtype();
     const Device device = hidden_states.device();
@@ -727,6 +736,8 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
     const int img_tokens = hidden_states.shape[1];
     const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
+    const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
+    const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
 
     Tensor concat;
@@ -735,6 +746,14 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
         if (size_t(layer) < transformer_blocks.size()) {
             auto &block = transformer_blocks.at(layer);
             std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+
+            if (controlnet_block_samples.valid()) {
+                int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
+                int block_index = layer / interval_control;
+                // Xlabs ControlNet
+                // block_index = layer % num_controlnet_block_samples;
+                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
+            }
         } else {
             if (size_t(layer) == transformer_blocks.size()) {
                 // txt first, same as diffusers
@@ -745,10 +764,21 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
                 }
                 hidden_states = concat;
                 encoder_hidden_states = {};
             }
 
             auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
             hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
+
+            if (controlnet_single_block_samples.valid()) {
+                int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
+                int block_index = (layer - transformer_blocks.size()) / interval_control;
+                // Xlabs ControlNet
+                // block_index = layer % num_controlnet_single_block_samples
+                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
+                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
+            }
         }
     };
 
     auto load = [&](int layer) {
@@ -776,6 +806,50 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
     return hidden_states;
 }
 
+std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
+                                                    Tensor hidden_states,
+                                                    Tensor encoder_hidden_states,
+                                                    Tensor temb,
+                                                    Tensor rotary_emb_img,
+                                                    Tensor rotary_emb_context,
+                                                    Tensor controlnet_block_samples,
+                                                    Tensor controlnet_single_block_samples) {
+
+    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
+        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+
+    const int txt_tokens = encoder_hidden_states.shape[1];
+    const int img_tokens = hidden_states.shape[1];
+    const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
+    const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
+
+    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
+        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
+        int block_index = layer / interval_control;
+        // Xlabs ControlNet
+        // block_index = layer % num_controlnet_block_samples;
+        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
+    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
+        int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
+        int block_index = (layer - transformer_blocks.size()) / interval_control;
+        // Xlabs ControlNet
+        // block_index = layer % num_controlnet_single_block_samples
+        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
+        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
+    }
+
+    return {hidden_states, encoder_hidden_states};
+}
+
 void FluxModel::setAttentionImpl(AttentionImpl impl) {
     for (auto &&block : this->transformer_blocks) {
         block->attnImpl = impl;
```
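The `ceilDiv` indexing above spreads a short list of ControlNet residuals over a longer stack of transformer layers: consecutive layers share one residual. A plain-Python illustration with hypothetical counts (FLUX.1-dev has 19 double transformer blocks; the sample count below is made up):

```python
import math

num_layers = 19    # FLUX.1-dev double transformer blocks
num_samples = 5    # hypothetical number of ControlNet block residuals

interval_control = math.ceil(num_layers / num_samples)  # ceilDiv(19, 5) == 4
for layer in range(num_layers):
    block_index = layer // interval_control
    print(f"layer {layer:2d} adds controlnet_block_samples[{block_index}]")
# layers 0-3 -> sample 0, 4-7 -> sample 1, ..., 16-18 -> sample 4
```

The commented-out `layer % num_controlnet_block_samples` variant cycles through the samples instead, which is the convention used by XLabs ControlNets.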
src/FluxModel.h

```diff
@@ -138,8 +138,25 @@ private:
 class FluxModel : public Module {
 public:
     FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer = false);
+    Tensor forward(Tensor hidden_states,
+                   Tensor encoder_hidden_states,
+                   Tensor temb,
+                   Tensor rotary_emb_img,
+                   Tensor rotary_emb_context,
+                   Tensor rotary_emb_single,
+                   Tensor controlnet_block_samples,
+                   Tensor controlnet_single_block_samples,
+                   bool skip_first_layer = false);
+    std::tuple<Tensor, Tensor> forward_layer(size_t layer,
+                                             Tensor hidden_states,
+                                             Tensor encoder_hidden_states,
+                                             Tensor temb,
+                                             Tensor rotary_emb_img,
+                                             Tensor rotary_emb_context,
+                                             Tensor controlnet_block_samples,
+                                             Tensor controlnet_single_block_samples);
     void setAttentionImpl(AttentionImpl impl);
 
 public:
```