Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
235238bd
Commit
235238bd
authored
Apr 02, 2025
by
Hyunsung Lee
Committed by
Zhekai Zhang
Apr 04, 2025
Browse files
Add controlnet
parent
63913f29
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
467 additions
and
66 deletions
+467
-66
examples/controlnet-flux-cache.py
examples/controlnet-flux-cache.py
+60
-0
examples/controlnet-flux.py
examples/controlnet-flux.py
+59
-0
examples/ref-controlnet.py
examples/ref-controlnet.py
+48
-0
nunchaku/csrc/flux.h
nunchaku/csrc/flux.h
+13
-5
nunchaku/csrc/gemm.h
nunchaku/csrc/gemm.h
+5
-5
nunchaku/csrc/pybind.cpp
nunchaku/csrc/pybind.cpp
+25
-6
nunchaku/models/transformers/transformer_flux.py
nunchaku/models/transformers/transformer_flux.py
+118
-2
src/FluxModel.cpp
src/FluxModel.cpp
+117
-43
src/FluxModel.h
src/FluxModel.h
+21
-4
src/interop/torch.h
src/interop/torch.h
+1
-1
No files found.
examples/controlnet-flux-cache.py
0 → 100644
View file @
235238bd
# Example: FLUX.1-dev guided by a multi-ControlNet (Union-Pro), running on the
# Nunchaku INT4-quantized transformer with first-block caching enabled.
import random

import numpy as np
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

# Seed every RNG in play so repeated runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'

controlnet_union = FluxControlNetModel.from_pretrained(
    controlnet_model_union, torch_dtype=torch.bfloat16
)
# we always recommend loading via FluxMultiControlNetModel
controlnet = FluxMultiControlNetModel([controlnet_union])

# Swap the pipeline's transformer for the Nunchaku INT4 version.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
# Enable residual caching on the pipeline to skip redundant transformer work.
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
pipe.to("cuda")

prompt = 'A anime style girl with messy beach waves.'

# Two conditioning images, each paired with its control mode for the Union model.
control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
control_mode_depth = 2
control_image_canny = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)
control_mode_canny = 0

# Match the output resolution to the conditioning image.
width, height = control_image_depth.size

image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
examples/controlnet-flux.py
0 → 100644
View file @
235238bd
# Example: FLUX.1-dev guided by a multi-ControlNet (Union-Pro), running on the
# Nunchaku INT4-quantized transformer (no first-block caching — see
# controlnet-flux-cache.py for the cached variant).
import random

import numpy as np
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel

# Seed every RNG in play so repeated runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'

controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend loading via FluxMultiControlNetModel

# Swap the pipeline's transformer for the Nunchaku INT4 version.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, transformer=transformer, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

prompt = 'A anime style girl with messy beach waves.'

# Two conditioning images, each paired with its control mode for the Union model.
control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
control_mode_depth = 2
control_image_canny = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)
control_mode_canny = 0

# Match the output resolution to the conditioning image.
width, height = control_image_depth.size

image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
examples/ref-controlnet.py
0 → 100644
View file @
235238bd
# Reference example: FLUX.1-dev + ControlNet-Union-Pro with the stock BF16
# transformer (no Nunchaku quantization), for comparison against the
# nunchaku-controlnet examples. The unused nunchaku import was removed so this
# baseline script runs without nunchaku installed.
import random

import numpy as np
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image

# Seed every RNG in play so repeated runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'

controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend loading via FluxMultiControlNetModel

pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = 'A anime style girl with messy beach waves.'

# Two conditioning images, each paired with its control mode for the Union model.
control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
control_mode_depth = 2
control_image_canny = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)
control_mode_canny = 0

# Match the output resolution to the conditioning image.
width, height = control_image_depth.size

image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("reference-controlnet-flux.1-dev.png")
nunchaku/csrc/flux.h
View file @
235238bd
...
@@ -35,6 +35,8 @@ public:
...
@@ -35,6 +35,8 @@ public:
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_context
,
torch
::
Tensor
rotary_emb_context
,
torch
::
Tensor
rotary_emb_single
,
torch
::
Tensor
rotary_emb_single
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
,
bool
skip_first_layer
=
false
)
bool
skip_first_layer
=
false
)
{
{
checkModel
();
checkModel
();
...
@@ -56,6 +58,8 @@ public:
...
@@ -56,6 +58,8 @@ public:
from_torch
(
rotary_emb_img
),
from_torch
(
rotary_emb_img
),
from_torch
(
rotary_emb_context
),
from_torch
(
rotary_emb_context
),
from_torch
(
rotary_emb_single
),
from_torch
(
rotary_emb_single
),
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{},
skip_first_layer
skip_first_layer
);
);
...
@@ -71,7 +75,9 @@ public:
...
@@ -71,7 +75,9 @@ public:
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_context
)
torch
::
Tensor
rotary_emb_context
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
)
{
{
CUDADeviceContext
ctx
(
deviceId
);
CUDADeviceContext
ctx
(
deviceId
);
...
@@ -83,17 +89,19 @@ public:
...
@@ -83,17 +89,19 @@ public:
rotary_emb_img
=
rotary_emb_img
.
contiguous
();
rotary_emb_img
=
rotary_emb_img
.
contiguous
();
rotary_emb_context
=
rotary_emb_context
.
contiguous
();
rotary_emb_context
=
rotary_emb_context
.
contiguous
();
auto
&&
[
result_img
,
result_txt
]
=
net
->
transformer_blocks
.
at
(
idx
)
->
forward
(
auto
&&
[
hidden_states_
,
encoder_hidden_states_
]
=
net
->
forward_layer
(
idx
,
from_torch
(
hidden_states
),
from_torch
(
hidden_states
),
from_torch
(
encoder_hidden_states
),
from_torch
(
encoder_hidden_states
),
from_torch
(
temb
),
from_torch
(
temb
),
from_torch
(
rotary_emb_img
),
from_torch
(
rotary_emb_img
),
from_torch
(
rotary_emb_context
),
from_torch
(
rotary_emb_context
),
0.0
f
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{}
);
);
hidden_states
=
to_torch
(
result_img
);
hidden_states
=
to_torch
(
hidden_states_
);
encoder_hidden_states
=
to_torch
(
result_txt
);
encoder_hidden_states
=
to_torch
(
encoder_hidden_states_
);
Tensor
::
synchronizeDevice
();
Tensor
::
synchronizeDevice
();
return
{
hidden_states
,
encoder_hidden_states
};
return
{
hidden_states
,
encoder_hidden_states
};
...
...
nunchaku/csrc/gemm.h
View file @
235238bd
...
@@ -10,7 +10,7 @@ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
...
@@ -10,7 +10,7 @@ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
public:
void
init
(
int64_t
in_features
,
int64_t
out_features
,
bool
bias
,
bool
use_fp4
,
bool
bf16
,
int8_t
deviceId
)
{
void
init
(
int64_t
in_features
,
int64_t
out_features
,
bool
bias
,
bool
use_fp4
,
bool
bf16
,
int8_t
deviceId
)
{
spdlog
::
info
(
"Initializing QuantizedGEMM"
);
spdlog
::
info
(
"Initializing QuantizedGEMM"
);
size_t
val
=
0
;
size_t
val
=
0
;
checkCUDA
(
cudaDeviceSetLimit
(
cudaLimitStackSize
,
8192
));
checkCUDA
(
cudaDeviceSetLimit
(
cudaLimitStackSize
,
8192
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
...
@@ -27,7 +27,7 @@ public:
...
@@ -27,7 +27,7 @@ public:
x
=
x
.
contiguous
();
x
=
x
.
contiguous
();
Tensor
result
=
net
->
forward
(
from_torch
(
x
));
Tensor
result
=
net
->
forward
(
from_torch
(
x
));
torch
::
Tensor
output
=
to_torch
(
result
);
torch
::
Tensor
output
=
to_torch
(
result
);
Tensor
::
synchronizeDevice
();
Tensor
::
synchronizeDevice
();
...
@@ -48,7 +48,7 @@ public:
...
@@ -48,7 +48,7 @@ public:
const
int
M
=
x
.
shape
[
0
];
const
int
M
=
x
.
shape
[
0
];
const
int
K
=
x
.
shape
[
1
]
*
2
;
const
int
K
=
x
.
shape
[
1
]
*
2
;
assert
(
x
.
dtype
()
==
Tensor
::
INT8
);
assert
(
x
.
dtype
()
==
Tensor
::
INT8
);
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
...
@@ -83,7 +83,7 @@ public:
...
@@ -83,7 +83,7 @@ public:
}
}
}
}
}
}
ss
<<
std
::
endl
;
ss
<<
std
::
endl
;
return
ss
.
str
();
return
ss
.
str
();
}
}
...
@@ -99,7 +99,7 @@ public:
...
@@ -99,7 +99,7 @@ public:
from_torch
(
x
),
from_torch
(
x
),
fuse_glu
fuse_glu
);
);
Tensor
act
=
qout
.
act
.
copy
(
Device
::
cpu
());
Tensor
act
=
qout
.
act
.
copy
(
Device
::
cpu
());
Tensor
ascales
=
qout
.
ascales
.
copy
(
Device
::
cpu
());
Tensor
ascales
=
qout
.
ascales
.
copy
(
Device
::
cpu
());
Tensor
lora_act
=
qout
.
lora_act
.
copy
(
Device
::
cpu
());
Tensor
lora_act
=
qout
.
lora_act
.
copy
(
Device
::
cpu
());
...
...
nunchaku/csrc/pybind.cpp
View file @
235238bd
...
@@ -18,16 +18,35 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -18,16 +18,35 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
arg
(
"deviceId"
)
py
::
arg
(
"deviceId"
)
)
)
.
def
(
"reset"
,
&
QuantizedFluxModel
::
reset
)
.
def
(
"reset"
,
&
QuantizedFluxModel
::
reset
)
.
def
(
"load"
,
&
QuantizedFluxModel
::
load
,
.
def
(
"load"
,
&
QuantizedFluxModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
py
::
arg
(
"partial"
)
=
false
)
)
.
def
(
"loadDict"
,
&
QuantizedFluxModel
::
loadDict
,
.
def
(
"loadDict"
,
&
QuantizedFluxModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
py
::
arg
(
"partial"
)
=
false
)
)
.
def
(
"forward"
,
&
QuantizedFluxModel
::
forward
)
.
def
(
"forward"
,
&
QuantizedFluxModel
::
forward
,
.
def
(
"forward_layer"
,
&
QuantizedFluxModel
::
forward_layer
)
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"temb"
),
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"rotary_emb_single"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"skip_first_layer"
)
=
false
)
.
def
(
"forward_layer"
,
&
QuantizedFluxModel
::
forward_layer
,
py
::
arg
(
"idx"
),
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"temb"
),
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
()
)
.
def
(
"forward_single_layer"
,
&
QuantizedFluxModel
::
forward_single_layer
)
.
def
(
"forward_single_layer"
,
&
QuantizedFluxModel
::
forward_single_layer
)
.
def
(
"startDebug"
,
&
QuantizedFluxModel
::
startDebug
)
.
def
(
"startDebug"
,
&
QuantizedFluxModel
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedFluxModel
::
stopDebug
)
.
def
(
"stopDebug"
,
&
QuantizedFluxModel
::
stopDebug
)
...
@@ -46,11 +65,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -46,11 +65,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
arg
(
"deviceId"
)
py
::
arg
(
"deviceId"
)
)
)
.
def
(
"reset"
,
&
QuantizedSanaModel
::
reset
)
.
def
(
"reset"
,
&
QuantizedSanaModel
::
reset
)
.
def
(
"load"
,
&
QuantizedSanaModel
::
load
,
.
def
(
"load"
,
&
QuantizedSanaModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
py
::
arg
(
"partial"
)
=
false
)
)
.
def
(
"loadDict"
,
&
QuantizedSanaModel
::
loadDict
,
.
def
(
"loadDict"
,
&
QuantizedSanaModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
py
::
arg
(
"partial"
)
=
false
)
)
...
...
nunchaku/models/transformers/transformer_flux.py
View file @
235238bd
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
logging
import
logging
import
os
import
os
...
@@ -5,6 +7,7 @@ import diffusers
...
@@ -5,6 +7,7 @@ import diffusers
import
torch
import
torch
from
diffusers
import
FluxTransformer2DModel
from
diffusers
import
FluxTransformer2DModel
from
diffusers.configuration_utils
import
register_to_config
from
diffusers.configuration_utils
import
register_to_config
from
diffusers.models.modeling_outputs
import
Transformer2DModelOutput
from
huggingface_hub
import
utils
from
huggingface_hub
import
utils
from
packaging.version
import
Version
from
packaging.version
import
Version
from
safetensors.torch
import
load_file
,
save_file
from
safetensors.torch
import
load_file
,
save_file
...
@@ -62,6 +65,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -62,6 +65,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
encoder_hidden_states
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
,
image_rotary_emb
:
torch
.
Tensor
,
image_rotary_emb
:
torch
.
Tensor
,
joint_attention_kwargs
=
None
,
joint_attention_kwargs
=
None
,
controlnet_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
skip_first_layer
=
False
,
skip_first_layer
=
False
,
):
):
batch_size
=
hidden_states
.
shape
[
0
]
batch_size
=
hidden_states
.
shape
[
0
]
...
@@ -76,6 +81,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -76,6 +81,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
temb
=
temb
.
to
(
self
.
dtype
).
to
(
self
.
device
)
temb
=
temb
.
to
(
self
.
dtype
).
to
(
self
.
device
)
image_rotary_emb
=
image_rotary_emb
.
to
(
self
.
device
)
image_rotary_emb
=
image_rotary_emb
.
to
(
self
.
device
)
if
controlnet_block_samples
is
not
None
:
controlnet_block_samples
=
torch
.
stack
(
controlnet_block_samples
).
to
(
self
.
device
)
if
controlnet_single_block_samples
is
not
None
:
controlnet_single_block_samples
=
torch
.
stack
(
controlnet_single_block_samples
).
to
(
self
.
device
)
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
...
@@ -89,7 +99,6 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -89,7 +99,6 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_txt
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_txt
,
256
,
1
))
rotary_emb_txt
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_txt
,
256
,
1
))
rotary_emb_img
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_img
,
256
,
1
))
rotary_emb_img
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_img
,
256
,
1
))
rotary_emb_single
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_single
,
256
,
1
))
rotary_emb_single
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_single
,
256
,
1
))
hidden_states
=
self
.
m
.
forward
(
hidden_states
=
self
.
m
.
forward
(
hidden_states
,
hidden_states
,
encoder_hidden_states
,
encoder_hidden_states
,
...
@@ -97,6 +106,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -97,6 +106,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img
,
rotary_emb_img
,
rotary_emb_txt
,
rotary_emb_txt
,
rotary_emb_single
,
rotary_emb_single
,
controlnet_block_samples
,
controlnet_single_block_samples
,
skip_first_layer
,
skip_first_layer
,
)
)
...
@@ -115,6 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -115,6 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
temb
:
torch
.
Tensor
,
temb
:
torch
.
Tensor
,
image_rotary_emb
:
torch
.
Tensor
,
image_rotary_emb
:
torch
.
Tensor
,
joint_attention_kwargs
=
None
,
joint_attention_kwargs
=
None
,
controlnet_block_samples
=
None
,
controlnet_single_block_samples
=
None
):
):
batch_size
=
hidden_states
.
shape
[
0
]
batch_size
=
hidden_states
.
shape
[
0
]
txt_tokens
=
encoder_hidden_states
.
shape
[
1
]
txt_tokens
=
encoder_hidden_states
.
shape
[
1
]
...
@@ -128,6 +141,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -128,6 +141,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
temb
=
temb
.
to
(
self
.
dtype
).
to
(
self
.
device
)
temb
=
temb
.
to
(
self
.
dtype
).
to
(
self
.
device
)
image_rotary_emb
=
image_rotary_emb
.
to
(
self
.
device
)
image_rotary_emb
=
image_rotary_emb
.
to
(
self
.
device
)
if
controlnet_block_samples
is
not
None
:
controlnet_block_samples
=
torch
.
stack
(
controlnet_block_samples
).
to
(
self
.
device
)
if
controlnet_single_block_samples
is
not
None
:
controlnet_single_block_samples
=
torch
.
stack
(
controlnet_single_block_samples
).
to
(
self
.
device
)
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
...
@@ -141,7 +159,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -141,7 +159,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_img
,
256
,
1
))
rotary_emb_img
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_img
,
256
,
1
))
hidden_states
,
encoder_hidden_states
=
self
.
m
.
forward_layer
(
hidden_states
,
encoder_hidden_states
=
self
.
m
.
forward_layer
(
idx
,
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_txt
idx
,
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_txt
,
controlnet_block_samples
,
controlnet_single_block_samples
)
)
hidden_states
=
hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
hidden_states
=
hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
...
@@ -473,3 +492,100 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
...
@@ -473,3 +492,100 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
state_dict
.
update
(
updated_vectors
)
state_dict
.
update
(
updated_vectors
)
self
.
transformer_blocks
[
0
].
m
.
loadDict
(
state_dict
,
True
)
self
.
transformer_blocks
[
0
].
m
.
loadDict
(
state_dict
,
True
)
self
.
reset_x_embedder
()
self
.
reset_x_embedder
()
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    Copied from diffusers.models.flux.transformer_flux.py, extended to thread
    `controlnet_block_samples` / `controlnet_single_block_samples` through to
    the fused Nunchaku transformer block.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
            Input `hidden_states`.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
            from the embeddings of input conditions.
        timestep ( `torch.LongTensor`):
            Used to indicate denoising step.
        block_controlnet_hidden_states: (`list` of `torch.Tensor`):
            A list of tensors that if specified are added to the residuals of transformer blocks.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.
    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    # Project packed latents into the transformer's hidden dimension.
    hidden_states = self.x_embedder(hidden_states)

    # Timestep (and optional guidance) are scaled by 1000 before embedding.
    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000
    else:
        guidance = None
    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    # Accept deprecated 3-D id tensors by dropping the batch dimension.
    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        img_ids = img_ids[0]

    # Rotary positional embeddings cover the concatenated text+image sequence.
    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)

    # IP-adapter embeds, if supplied, are projected and forwarded via kwargs.
    if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
        ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

    # All transformer layers are fused into a single Nunchaku block; the
    # controlnet residuals are passed straight through to it.
    nunchaku_block = self.transformer_blocks[0]
    encoder_hidden_states, hidden_states = nunchaku_block(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        image_rotary_emb=image_rotary_emb,
        joint_attention_kwargs=joint_attention_kwargs,
        controlnet_block_samples=controlnet_block_samples,
        controlnet_single_block_samples=controlnet_single_block_samples,
    )

    # Re-concatenate then slice off the text tokens, keeping only image tokens.
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
src/FluxModel.cpp
View file @
235238bd
...
@@ -40,7 +40,7 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
...
@@ -40,7 +40,7 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
AdaLayerNormZeroSingle
::
AdaLayerNormZeroSingle
(
int
dim
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
AdaLayerNormZeroSingle
::
AdaLayerNormZeroSingle
(
int
dim
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
dim
(
dim
),
linear
(
dim
,
3
*
dim
,
true
,
dtype
,
device
),
linear
(
dim
,
3
*
dim
,
true
,
dtype
,
device
),
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
{
{
registerChildren
registerChildren
(
linear
,
"linear"
)
(
linear
,
"linear"
)
...
@@ -59,12 +59,12 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
...
@@ -59,12 +59,12 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
debug
(
"x"
,
x
);
debug
(
"x"
,
x
);
Tensor
norm_x
=
norm
.
forward
(
x
);
Tensor
norm_x
=
norm
.
forward
(
x
);
debug
(
"norm_x"
,
norm_x
);
debug
(
"norm_x"
,
norm_x
);
kernels
::
mul_add
(
norm_x
,
scale_msa
,
shift_msa
);
kernels
::
mul_add
(
norm_x
,
scale_msa
,
shift_msa
);
return
Output
{
norm_x
,
gate_msa
};
return
Output
{
norm_x
,
gate_msa
};
}
}
AdaLayerNormZero
::
AdaLayerNormZero
(
int
dim
,
bool
pre_only
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
AdaLayerNormZero
::
AdaLayerNormZero
(
int
dim
,
bool
pre_only
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
pre_only
(
pre_only
),
dim
(
dim
),
pre_only
(
pre_only
),
linear
(
dim
,
pre_only
?
2
*
dim
:
6
*
dim
,
true
,
dtype
,
device
),
linear
(
dim
,
pre_only
?
2
*
dim
:
6
*
dim
,
true
,
dtype
,
device
),
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
...
@@ -91,7 +91,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
...
@@ -91,7 +91,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
kernels
::
mul_add
(
norm_x
,
scale_msa
,
shift_msa
);
kernels
::
mul_add
(
norm_x
,
scale_msa
,
shift_msa
);
debug
(
"norm_x_scaled"
,
norm_x
);
debug
(
"norm_x_scaled"
,
norm_x
);
return
Output
{
norm_x
};
return
Output
{
norm_x
};
}
else
{
}
else
{
auto
&&
[
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
]
=
kernels
::
split_mod
<
6
>
(
emb
);
auto
&&
[
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
]
=
kernels
::
split_mod
<
6
>
(
emb
);
...
@@ -108,7 +108,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
...
@@ -108,7 +108,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
}
}
Attention
::
Attention
(
int
num_heads
,
int
dim_head
,
Device
device
)
:
Attention
::
Attention
(
int
num_heads
,
int
dim_head
,
Device
device
)
:
num_heads
(
num_heads
),
dim_head
(
dim_head
),
force_fp16
(
false
)
num_heads
(
num_heads
),
dim_head
(
dim_head
),
force_fp16
(
false
)
{
{
headmask_type
=
Tensor
::
allocate
({
num_heads
},
Tensor
::
INT32
,
Device
::
cpu
());
headmask_type
=
Tensor
::
allocate
({
num_heads
},
Tensor
::
INT32
,
Device
::
cpu
());
...
@@ -151,7 +151,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
...
@@ -151,7 +151,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
gemm_batched_fp16
(
pool_q
,
pool_k
,
pool_s
);
gemm_batched_fp16
(
pool_q
,
pool_k
,
pool_s
);
}
}
}
}
blockmask
=
kernels
::
topk
(
pool_score
,
pool_tokens
*
(
1
-
sparsityRatio
));
blockmask
=
kernels
::
topk
(
pool_score
,
pool_tokens
*
(
1
-
sparsityRatio
));
if
(
cu_seqlens_cpu
.
valid
())
{
if
(
cu_seqlens_cpu
.
valid
())
{
...
@@ -227,9 +227,9 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
...
@@ -227,9 +227,9 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
false
false
).front();
).front();
Tensor raw_attn_output = mha_fwd(q, k, v,
Tensor raw_attn_output = mha_fwd(q, k, v,
0.0f,
0.0f,
pow(q.shape[-1], (-0.5)),
pow(q.shape[-1], (-0.5)),
false, -1, -1, false
false, -1, -1, false
).front();
).front();
...
@@ -261,7 +261,7 @@ void Attention::setForceFP16(Module *module, bool value) {
...
@@ -261,7 +261,7 @@ void Attention::setForceFP16(Module *module, bool value) {
}
}
FluxSingleTransformerBlock
::
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
FluxSingleTransformerBlock
::
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
dim
(
dim
),
dim_head
(
attention_head_dim
/
num_attention_heads
),
dim_head
(
attention_head_dim
/
num_attention_heads
),
num_heads
(
num_attention_heads
),
num_heads
(
num_attention_heads
),
mlp_hidden_dim
(
dim
*
mlp_ratio
),
mlp_hidden_dim
(
dim
*
mlp_ratio
),
...
@@ -311,7 +311,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -311,7 +311,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
qkv_proj
.
forward
(
norm_hidden_states
,
qkv
,
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
qkv_proj
.
forward
(
norm_hidden_states
,
qkv
,
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
debug
(
"qkv"
,
qkv
);
debug
(
"qkv"
,
qkv
);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
attn_output
=
attn
.
forward
(
qkv
,
{},
0
);
attn_output
=
attn
.
forward
(
qkv
,
{},
0
);
attn_output
=
attn_output
.
reshape
({
batch_size
,
num_tokens
,
num_heads
*
dim_head
});
attn_output
=
attn_output
.
reshape
({
batch_size
,
num_tokens
,
num_heads
*
dim_head
});
}
else
if
(
attnImpl
==
AttentionImpl
::
NunchakuFP16
)
{
}
else
if
(
attnImpl
==
AttentionImpl
::
NunchakuFP16
)
{
...
@@ -340,7 +340,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -340,7 +340,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
debug
(
"raw_attn_output"
,
attn_output
);
debug
(
"raw_attn_output"
,
attn_output
);
attn_output
=
forward_fc
(
out_proj
,
attn_output
);
attn_output
=
forward_fc
(
out_proj
,
attn_output
);
debug
(
"attn_output"
,
attn_output
);
debug
(
"attn_output"
,
attn_output
);
...
@@ -350,7 +350,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -350,7 +350,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
hidden_states
=
kernels
::
add
(
attn_output
,
ff_output
);
hidden_states
=
kernels
::
add
(
attn_output
,
ff_output
);
debug
(
"attn_ff_output"
,
hidden_states
);
debug
(
"attn_ff_output"
,
hidden_states
);
kernels
::
mul_add
(
hidden_states
,
gate
,
residual
);
kernels
::
mul_add
(
hidden_states
,
gate
,
residual
);
nvtxRangePop
();
nvtxRangePop
();
...
@@ -358,7 +358,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -358,7 +358,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
return
hidden_states
;
return
hidden_states
;
}
}
JointTransformerBlock
::
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
JointTransformerBlock
::
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
dim
(
dim
),
dim_head
(
attention_head_dim
/
num_attention_heads
),
dim_head
(
attention_head_dim
/
num_attention_heads
),
num_heads
(
num_attention_heads
),
num_heads
(
num_attention_heads
),
...
@@ -416,7 +416,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -416,7 +416,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
int
num_tokens_img
=
hidden_states
.
shape
[
1
];
int
num_tokens_img
=
hidden_states
.
shape
[
1
];
int
num_tokens_txt
=
encoder_hidden_states
.
shape
[
1
];
int
num_tokens_txt
=
encoder_hidden_states
.
shape
[
1
];
assert
(
hidden_states
.
shape
[
2
]
==
dim
);
assert
(
hidden_states
.
shape
[
2
]
==
dim
);
assert
(
encoder_hidden_states
.
shape
[
2
]
==
dim
);
assert
(
encoder_hidden_states
.
shape
[
2
]
==
dim
);
...
@@ -439,7 +439,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -439,7 +439,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePop
();
nvtxRangePop
();
auto
stream
=
getCurrentCUDAStream
();
auto
stream
=
getCurrentCUDAStream
();
int
num_tokens_img_pad
=
0
,
num_tokens_txt_pad
=
0
;
int
num_tokens_img_pad
=
0
,
num_tokens_txt_pad
=
0
;
Tensor
raw_attn_output
;
Tensor
raw_attn_output
;
...
@@ -449,66 +449,66 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -449,66 +449,66 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor
concat
;
Tensor
concat
;
Tensor
pool
;
Tensor
pool
;
{
{
nvtxRangePushA
(
"qkv_proj"
);
nvtxRangePushA
(
"qkv_proj"
);
const
bool
blockSparse
=
sparsityRatio
>
0
;
const
bool
blockSparse
=
sparsityRatio
>
0
;
const
int
poolTokens
=
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
;
const
int
poolTokens
=
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
;
concat
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img
+
num_tokens_txt
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
());
concat
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img
+
num_tokens_txt
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
());
pool
=
blockSparse
pool
=
blockSparse
?
Tensor
::
allocate
({
batch_size
,
poolTokens
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
())
?
Tensor
::
allocate
({
batch_size
,
poolTokens
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
())
:
Tensor
{};
:
Tensor
{};
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
// img first
// img first
Tensor
qkv
=
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
);
Tensor
qkv
=
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
);
Tensor
qkv_context
=
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
,
num_tokens_img
+
num_tokens_txt
);
Tensor
qkv_context
=
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
,
num_tokens_img
+
num_tokens_txt
);
Tensor
pool_qkv
=
pool
.
valid
()
Tensor
pool_qkv
=
pool
.
valid
()
?
pool
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
/
POOL_SIZE
)
?
pool
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
/
POOL_SIZE
)
:
Tensor
{};
:
Tensor
{};
Tensor
pool_qkv_context
=
pool
.
valid
()
Tensor
pool_qkv_context
=
pool
.
valid
()
?
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
/
POOL_SIZE
,
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
)
?
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
/
POOL_SIZE
,
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
)
:
Tensor
{};
:
Tensor
{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// debug("qkv_raw", qkv);
// debug("qkv_raw", qkv);
debug
(
"rotary_emb"
,
rotary_emb
);
debug
(
"rotary_emb"
,
rotary_emb
);
qkv_proj
.
forward
(
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv
,
pool_qkv
,
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
qkv_proj
.
forward
(
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv
,
pool_qkv
,
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
debug
(
"qkv"
,
qkv
);
debug
(
"qkv"
,
qkv
);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// debug("qkv_context_raw", qkv_context);
// debug("qkv_context_raw", qkv_context);
debug
(
"rotary_emb_context"
,
rotary_emb_context
);
debug
(
"rotary_emb_context"
,
rotary_emb_context
);
qkv_proj_context
.
forward
(
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv_context
,
pool_qkv_context
,
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
);
qkv_proj_context
.
forward
(
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv_context
,
pool_qkv_context
,
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
);
debug
(
"qkv_context"
,
qkv_context
);
debug
(
"qkv_context"
,
qkv_context
);
}
}
nvtxRangePop
();
nvtxRangePop
();
}
}
spdlog
::
debug
(
"concat={}"
,
concat
.
shape
.
str
());
spdlog
::
debug
(
"concat={}"
,
concat
.
shape
.
str
());
debug
(
"concat"
,
concat
);
debug
(
"concat"
,
concat
);
assert
(
concat
.
shape
[
2
]
==
num_heads
*
dim_head
*
3
);
assert
(
concat
.
shape
[
2
]
==
num_heads
*
dim_head
*
3
);
nvtxRangePushA
(
"Attention"
);
nvtxRangePushA
(
"Attention"
);
raw_attn_output
=
attn
.
forward
(
concat
,
pool
,
sparsityRatio
);
raw_attn_output
=
attn
.
forward
(
concat
,
pool
,
sparsityRatio
);
nvtxRangePop
();
nvtxRangePop
();
spdlog
::
debug
(
"raw_attn_output={}"
,
raw_attn_output
.
shape
.
str
());
spdlog
::
debug
(
"raw_attn_output={}"
,
raw_attn_output
.
shape
.
str
());
raw_attn_output
=
raw_attn_output
.
view
({
batch_size
,
num_tokens_img
+
num_tokens_txt
,
num_heads
,
dim_head
});
raw_attn_output
=
raw_attn_output
.
view
({
batch_size
,
num_tokens_img
+
num_tokens_txt
,
num_heads
,
dim_head
});
}
else
if
(
attnImpl
==
AttentionImpl
::
NunchakuFP16
)
{
}
else
if
(
attnImpl
==
AttentionImpl
::
NunchakuFP16
)
{
num_tokens_img_pad
=
ceilDiv
(
num_tokens_img
,
256
)
*
256
;
num_tokens_img_pad
=
ceilDiv
(
num_tokens_img
,
256
)
*
256
;
num_tokens_txt_pad
=
ceilDiv
(
num_tokens_txt
,
256
)
*
256
;
num_tokens_txt_pad
=
ceilDiv
(
num_tokens_txt
,
256
)
*
256
;
...
@@ -517,11 +517,11 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -517,11 +517,11 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
{
{
nvtxRangePushA
(
"qkv_proj"
);
nvtxRangePushA
(
"qkv_proj"
);
concat_q
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
dim_head
},
Tensor
::
FP16
,
norm1_output
.
x
.
device
());
concat_q
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
dim_head
},
Tensor
::
FP16
,
norm1_output
.
x
.
device
());
concat_k
=
Tensor
::
empty_like
(
concat_q
);
concat_k
=
Tensor
::
empty_like
(
concat_q
);
concat_v
=
Tensor
::
empty_like
(
concat_q
);
concat_v
=
Tensor
::
empty_like
(
concat_q
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
// img first
// img first
auto
sliceImg
=
[
&
](
Tensor
x
)
{
auto
sliceImg
=
[
&
](
Tensor
x
)
{
...
@@ -530,12 +530,12 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -530,12 +530,12 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
auto
sliceTxt
=
[
&
](
Tensor
x
)
{
auto
sliceTxt
=
[
&
](
Tensor
x
)
{
return
x
.
slice
(
0
,
i
,
i
+
1
).
slice
(
2
,
num_tokens_img_pad
,
num_tokens_img_pad
+
num_tokens_txt_pad
);
return
x
.
slice
(
0
,
i
,
i
+
1
).
slice
(
2
,
num_tokens_img_pad
,
num_tokens_img_pad
+
num_tokens_txt_pad
);
};
};
qkv_proj
.
forward
(
qkv_proj
.
forward
(
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
sliceImg
(
concat_q
),
sliceImg
(
concat_k
),
sliceImg
(
concat_v
),
num_tokens_img
sliceImg
(
concat_q
),
sliceImg
(
concat_k
),
sliceImg
(
concat_v
),
num_tokens_img
);
);
qkv_proj_context
.
forward
(
qkv_proj_context
.
forward
(
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
,
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
,
sliceTxt
(
concat_q
),
sliceTxt
(
concat_k
),
sliceTxt
(
concat_v
),
num_tokens_txt
sliceTxt
(
concat_q
),
sliceTxt
(
concat_k
),
sliceTxt
(
concat_v
),
num_tokens_txt
...
@@ -545,7 +545,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -545,7 +545,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug
(
"concat_q"
,
concat_q
);
debug
(
"concat_q"
,
concat_q
);
debug
(
"concat_k"
,
concat_k
);
debug
(
"concat_k"
,
concat_k
);
debug
(
"concat_v"
,
concat_v
);
debug
(
"concat_v"
,
concat_v
);
nvtxRangePop
();
nvtxRangePop
();
}
}
...
@@ -718,7 +718,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
...
@@ -718,7 +718,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
}
}
}
}
Tensor
FluxModel
::
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
rotary_emb_single
,
bool
skip_first_layer
)
{
Tensor
FluxModel
::
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
rotary_emb_single
,
Tensor
controlnet_block_samples
,
Tensor
controlnet_single_block_samples
,
bool
skip_first_layer
)
{
const
int
batch_size
=
hidden_states
.
shape
[
0
];
const
int
batch_size
=
hidden_states
.
shape
[
0
];
const
Tensor
::
ScalarType
dtype
=
hidden_states
.
dtype
();
const
Tensor
::
ScalarType
dtype
=
hidden_states
.
dtype
();
const
Device
device
=
hidden_states
.
device
();
const
Device
device
=
hidden_states
.
device
();
...
@@ -727,6 +736,8 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
...
@@ -727,6 +736,8 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
const
int
img_tokens
=
hidden_states
.
shape
[
1
];
const
int
img_tokens
=
hidden_states
.
shape
[
1
];
const
int
numLayers
=
transformer_blocks
.
size
()
+
single_transformer_blocks
.
size
();
const
int
numLayers
=
transformer_blocks
.
size
()
+
single_transformer_blocks
.
size
();
const
int
num_controlnet_block_samples
=
controlnet_block_samples
.
shape
[
0
];
const
int
num_controlnet_single_block_samples
=
controlnet_single_block_samples
.
shape
[
0
];
Tensor
concat
;
Tensor
concat
;
...
@@ -735,6 +746,14 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
...
@@ -735,6 +746,14 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
if
(
size_t
(
layer
)
<
transformer_blocks
.
size
())
{
if
(
size_t
(
layer
)
<
transformer_blocks
.
size
())
{
auto
&
block
=
transformer_blocks
.
at
(
layer
);
auto
&
block
=
transformer_blocks
.
at
(
layer
);
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
block
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
block
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
if
(
controlnet_block_samples
.
valid
())
{
int
interval_control
=
ceilDiv
(
transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_block_samples
));
int
block_index
=
layer
/
interval_control
;
// Xlabs ControlNet
// block_index = layer % num_controlnet_block_samples;
hidden_states
=
kernels
::
add
(
hidden_states
,
controlnet_block_samples
[
block_index
]);
}
}
else
{
}
else
{
if
(
size_t
(
layer
)
==
transformer_blocks
.
size
())
{
if
(
size_t
(
layer
)
==
transformer_blocks
.
size
())
{
// txt first, same as diffusers
// txt first, same as diffusers
...
@@ -745,10 +764,21 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
...
@@ -745,10 +764,21 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
}
}
hidden_states
=
concat
;
hidden_states
=
concat
;
encoder_hidden_states
=
{};
encoder_hidden_states
=
{};
}
}
auto
&
block
=
single_transformer_blocks
.
at
(
layer
-
transformer_blocks
.
size
());
auto
&
block
=
single_transformer_blocks
.
at
(
layer
-
transformer_blocks
.
size
());
hidden_states
=
block
->
forward
(
hidden_states
,
temb
,
rotary_emb_single
);
hidden_states
=
block
->
forward
(
hidden_states
,
temb
,
rotary_emb_single
);
if
(
controlnet_single_block_samples
.
valid
())
{
int
interval_control
=
ceilDiv
(
single_transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_single_block_samples
));
int
block_index
=
(
layer
-
transformer_blocks
.
size
())
/
interval_control
;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
auto
slice
=
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
);
slice
=
kernels
::
add
(
slice
,
controlnet_single_block_samples
[
block_index
]);
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
slice
);
}
}
}
};
};
auto
load
=
[
&
](
int
layer
)
{
auto
load
=
[
&
](
int
layer
)
{
...
@@ -776,6 +806,50 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
...
@@ -776,6 +806,50 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
return
hidden_states
;
return
hidden_states
;
}
}
std
::
tuple
<
Tensor
,
Tensor
>
FluxModel
::
forward_layer
(
size_t
layer
,
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
controlnet_block_samples
,
Tensor
controlnet_single_block_samples
)
{
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
transformer_blocks
.
at
(
layer
)
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
const
int
txt_tokens
=
encoder_hidden_states
.
shape
[
1
];
const
int
img_tokens
=
hidden_states
.
shape
[
1
];
const
int
num_controlnet_block_samples
=
controlnet_block_samples
.
shape
[
0
];
const
int
num_controlnet_single_block_samples
=
controlnet_single_block_samples
.
shape
[
0
];
if
(
layer
<
transformer_blocks
.
size
()
&&
controlnet_block_samples
.
valid
())
{
int
interval_control
=
ceilDiv
(
transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_block_samples
));
int
block_index
=
layer
/
interval_control
;
// Xlabs ControlNet
// block_index = layer % num_controlnet_block_samples;
hidden_states
=
kernels
::
add
(
hidden_states
,
controlnet_block_samples
[
block_index
]);
}
else
if
(
layer
>=
transformer_blocks
.
size
()
&&
controlnet_single_block_samples
.
valid
())
{
int
interval_control
=
ceilDiv
(
single_transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_single_block_samples
));
int
block_index
=
(
layer
-
transformer_blocks
.
size
())
/
interval_control
;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
auto
slice
=
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
);
slice
=
kernels
::
add
(
slice
,
controlnet_single_block_samples
[
block_index
]);
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
slice
);
}
return
{
hidden_states
,
encoder_hidden_states
};
}
void
FluxModel
::
setAttentionImpl
(
AttentionImpl
impl
)
{
void
FluxModel
::
setAttentionImpl
(
AttentionImpl
impl
)
{
for
(
auto
&&
block
:
this
->
transformer_blocks
)
{
for
(
auto
&&
block
:
this
->
transformer_blocks
)
{
block
->
attnImpl
=
impl
;
block
->
attnImpl
=
impl
;
...
...
src/FluxModel.h
View file @
235238bd
...
@@ -61,7 +61,7 @@ private:
...
@@ -61,7 +61,7 @@ private:
class
Attention
:
public
Module
{
class
Attention
:
public
Module
{
public:
public:
static
constexpr
int
POOL_SIZE
=
128
;
static
constexpr
int
POOL_SIZE
=
128
;
Attention
(
int
num_heads
,
int
dim_head
,
Device
device
);
Attention
(
int
num_heads
,
int
dim_head
,
Device
device
);
Tensor
forward
(
Tensor
qkv
,
Tensor
pool_qkv
,
float
sparsityRatio
);
Tensor
forward
(
Tensor
qkv
,
Tensor
pool_qkv
,
float
sparsityRatio
);
...
@@ -138,13 +138,30 @@ private:
...
@@ -138,13 +138,30 @@ private:
class
FluxModel
:
public
Module
{
class
FluxModel
:
public
Module
{
public:
public:
FluxModel
(
bool
use_fp4
,
bool
offload
,
Tensor
::
ScalarType
dtype
,
Device
device
);
FluxModel
(
bool
use_fp4
,
bool
offload
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
rotary_emb_single
,
bool
skip_first_layer
=
false
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
rotary_emb_single
,
Tensor
controlnet_block_samples
,
Tensor
controlnet_single_block_samples
,
bool
skip_first_layer
=
false
);
std
::
tuple
<
Tensor
,
Tensor
>
forward_layer
(
size_t
layer
,
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
controlnet_block_samples
,
Tensor
controlnet_single_block_samples
);
void
setAttentionImpl
(
AttentionImpl
impl
);
void
setAttentionImpl
(
AttentionImpl
impl
);
public:
public:
const
Tensor
::
ScalarType
dtype
;
const
Tensor
::
ScalarType
dtype
;
std
::
vector
<
std
::
unique_ptr
<
JointTransformerBlock
>>
transformer_blocks
;
std
::
vector
<
std
::
unique_ptr
<
JointTransformerBlock
>>
transformer_blocks
;
std
::
vector
<
std
::
unique_ptr
<
FluxSingleTransformerBlock
>>
single_transformer_blocks
;
std
::
vector
<
std
::
unique_ptr
<
FluxSingleTransformerBlock
>>
single_transformer_blocks
;
...
...
src/interop/torch.h
View file @
235238bd
...
@@ -13,7 +13,7 @@ public:
...
@@ -13,7 +13,7 @@ public:
this
->
device
.
type
=
this
->
tensor
.
is_cuda
()
?
Device
::
CUDA
:
Device
::
CPU
;
this
->
device
.
type
=
this
->
tensor
.
is_cuda
()
?
Device
::
CUDA
:
Device
::
CPU
;
this
->
device
.
idx
=
this
->
tensor
.
get_device
();
this
->
device
.
idx
=
this
->
tensor
.
get_device
();
}
}
virtual
bool
isAsyncBuffer
()
override
{
virtual
bool
isAsyncBuffer
()
override
{
// TODO: figure out how torch manages memory
// TODO: figure out how torch manages memory
return
this
->
device
.
type
==
Device
::
CUDA
;
return
this
->
device
.
type
==
Device
::
CUDA
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment