Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
85855252
Unverified
Commit
85855252
authored
Apr 25, 2025
by
Muyang Li
Committed by
GitHub
Apr 25, 2025
Browse files
fix: fix the typo in FluxModel.cpp as in #297 (#317)
* fix: fix a typo * style: format the imports
parent
ccd93d1e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
4 deletions
+5
-4
nunchaku/models/transformers/transformer_flux.py
nunchaku/models/transformers/transformer_flux.py
+4
-3
src/FluxModel.cpp
src/FluxModel.cpp
+1
-1
No files found.
nunchaku/models/transformers/transformer_flux.py
View file @
85855252
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
logging
import
logging
import
os
import
os
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
diffusers
import
diffusers
import
torch
import
torch
...
@@ -10,7 +9,7 @@ from diffusers.configuration_utils import register_to_config
...
@@ -10,7 +9,7 @@ from diffusers.configuration_utils import register_to_config
from
diffusers.models.modeling_outputs
import
Transformer2DModelOutput
from
diffusers.models.modeling_outputs
import
Transformer2DModelOutput
from
huggingface_hub
import
utils
from
huggingface_hub
import
utils
from
packaging.version
import
Version
from
packaging.version
import
Version
from
safetensors.torch
import
load_file
,
save_file
from
safetensors.torch
import
load_file
from
torch
import
nn
from
torch
import
nn
from
.utils
import
NunchakuModelLoaderMixin
,
pad_tensor
from
.utils
import
NunchakuModelLoaderMixin
,
pad_tensor
...
@@ -180,9 +179,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -180,9 +179,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
encoder_hidden_states
=
encoder_hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
encoder_hidden_states
=
encoder_hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
return
encoder_hidden_states
,
hidden_states
return
encoder_hidden_states
,
hidden_states
def
__del__
(
self
):
def
__del__
(
self
):
self
.
m
.
reset
()
self
.
m
.
reset
()
## copied from diffusers 0.30.3
## copied from diffusers 0.30.3
def
rope
(
pos
:
torch
.
Tensor
,
dim
:
int
,
theta
:
int
)
->
torch
.
Tensor
:
def
rope
(
pos
:
torch
.
Tensor
,
dim
:
int
,
theta
:
int
)
->
torch
.
Tensor
:
assert
dim
%
2
==
0
,
"The dimension must be even."
assert
dim
%
2
==
0
,
"The dimension must be even."
...
...
src/FluxModel.cpp
View file @
85855252
...
@@ -526,7 +526,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -526,7 +526,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
?
pool
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
/
POOL_SIZE
)
?
pool
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
/
POOL_SIZE
)
:
Tensor
{};
:
Tensor
{};
Tensor
pool_qkv_context
=
pool
.
valid
()
Tensor
pool_qkv_context
=
pool
.
valid
()
?
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
/
POOL_SIZE
,
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
)
?
pool
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
/
POOL_SIZE
,
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
)
:
Tensor
{};
:
Tensor
{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment