renzhc / diffusers_dcu · Commits · e47baefc

Commit e47baefc, authored Jan 09, 2026 by renzhc

    fixed shape error

Parent: c6714fc3
Pipeline #3195 failed with stages in 0 seconds
Showing 2 changed files with 92 additions and 70 deletions:

    src/diffusers/models/attention_processor.py    +47 -29
    src/diffusers/models/model_loading_utils.py     +45 -41
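Throughout both files, "DCU OPT: TN->NN" marks the same optimization: nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T + bias (a "TN" GEMM), so the fork permutes the stored weight to (in_features, out_features) at load time and replaces each Linear call with a plain x @ weight ("NN" GEMM). A minimal sketch of that equivalence, using an arbitrary layer size rather than anything taken from the commit:

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(320, 640)                          # weight: (640, 320), bias: (640,)
x = torch.randn(2, 77, 320)

ref = layer(x)                                       # TN path: x @ weight.T + bias
w_nn = layer.weight.data.permute(1, 0).contiguous()  # pre-transpose once: (320, 640)
out = torch.matmul(x, w_nn) + layer.bias.data        # NN path

print(torch.allclose(ref, out, atol=1e-6))           # True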
src/diffusers/models/attention_processor.py
@@ -2034,15 +2034,29 @@ class AllegroAttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states)
+        # query = attn.to_q(hidden_states)
+        # DCU OPT: TN->NN
+        if attn.to_q.bias:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
+        else:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        # key = attn.to_k(encoder_hidden_states)
+        # value = attn.to_v(encoder_hidden_states)
+        # DCU OPT: TN->NN
+        if attn.to_k.bias:
+            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
+        else:
+            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data)
+        if attn.to_v.bias:
+            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
+        else:
+            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -2068,7 +2082,9 @@ class AllegroAttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        # hidden_states = attn.to_out[0](hidden_states)
+        # DCU OPT: TN->NN
+        hidden_states = torch.matmul(hidden_states, attn.to_out[0].weight.data) + attn.to_out[0].bias.data
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
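One thing to note about the AllegroAttnProcessor2_0 hunks above: the bias check is written as "if attn.to_q.bias:", whereas the AuraFlow and AttnProcessor2_0 hunks below use isinstance(..., torch.Tensor). Evaluating a Parameter with more than one element in a boolean context raises a RuntimeError, so the truthiness form only survives when the bias is None or has a single element. A small repro with a hypothetical projection layer:

import torch
import torch.nn as nn

proj = nn.Linear(320, 640, bias=True)

try:
    if proj.bias:                               # multi-element tensor in a boolean context
        pass
except RuntimeError as err:
    print(err)                                  # "Boolean value of Tensor with more than one element is ambiguous"

print(isinstance(proj.bias, torch.Tensor))      # True -- the check used in the other hunks
print(proj.bias is not None)                    # True -- an equivalent, simpler check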
@@ -2103,9 +2119,24 @@ class AuraFlowAttnProcessor2_0:
         batch_size = hidden_states.shape[0]

         # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        # query = attn.to_q(hidden_states)
+        # key = attn.to_k(hidden_states)
+        # value = attn.to_v(hidden_states)
+        # DCU OPT: TN->NN
+        if isinstance(attn.to_q.bias, torch.Tensor):
+            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
+        else:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data)
+        if isinstance(attn.to_k.bias, torch.Tensor):
+            key = torch.matmul(hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
+        else:
+            key = torch.matmul(hidden_states, attn.to_k.weight.data)
+        if isinstance(attn.to_v.bias, torch.Tensor):
+            value = torch.matmul(hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
+        else:
+            value = torch.matmul(hidden_states, attn.to_v.weight.data)

         # `context` projections.
         if encoder_hidden_states is not None:
@@ -2164,7 +2195,9 @@ class AuraFlowAttnProcessor2_0:
         )

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        # hidden_states = attn.to_out[0](hidden_states)
+        # DCU OPT: TN->NN
+        hidden_states = torch.matmul(hidden_states, attn.to_out[0].weight.data) + attn.to_out[0].bias.data
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

         if encoder_hidden_states is not None:
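The same four-line pattern (isinstance check on the bias, then matmul with or without the bias add) is repeated for every projection in the two AuraFlow hunks above. As an illustration only, a hypothetical helper that captures it (nn_linear is not a name used anywhere in the commit):

import torch

def nn_linear(x, layer):
    # Assumes layer.weight.data was already permuted to (in_features, out_features)
    # by the load-time pass in model_loading_utils.py, so a plain "NN" matmul applies.
    out = torch.matmul(x, layer.weight.data)
    if isinstance(layer.bias, torch.Tensor):
        out = out + layer.bias.data
    return out

# e.g. query = nn_linear(hidden_states, attn.to_q)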
@@ -2740,7 +2773,7 @@ class AttnProcessor2_0:
         # query = attn.to_q(hidden_states)
         # DCU OPT: TN->NN
-        if attn.to_q.bias:
+        if isinstance(attn.to_q.bias, torch.Tensor):
             query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
         else:
             query = torch.matmul(hidden_states, attn.to_q.weight.data)
@@ -2753,11 +2786,11 @@ class AttnProcessor2_0:
         # key = attn.to_k(encoder_hidden_states)
         # value = attn.to_v(encoder_hidden_states)
         # DCU OPT: TN->NN
-        if attn.to_k.bias:
+        if isinstance(attn.to_k.bias, torch.Tensor):
             key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
         else:
             key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data)
-        if attn.to_v.bias:
+        if isinstance(attn.to_v.bias, torch.Tensor):
             value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
         else:
             value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data)
@@ -2765,24 +2798,9 @@ class AttnProcessor2_0:
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # DCU OPT: TN->NN
-        if isinstance(attn.to_q.bias, torch.Tensor):
-            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
-        else:
-            query = torch.matmul(hidden_states, attn.to_q.weight.data)
-        if isinstance(attn.to_k.bias, torch.Tensor):
-            key = torch.matmul(hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
-        else:
-            key = torch.matmul(hidden_states, attn.to_k.weight.data)
-        if isinstance(attn.to_v.bias, torch.Tensor):
-            value = torch.matmul(hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
-        else:
-            value = torch.matmul(hidden_states, attn.to_v.weight.data)
-
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         if attn.norm_q is not None:
             query = attn.norm_q(query)
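The block deleted in the hunk above is the most plausible source of the "shape error" named in the commit message: it recomputed query, key and value a second time after the projections earlier in the call, and fed hidden_states rather than encoder_hidden_states into the key and value matmuls, which cannot work for cross-attention where the two inputs differ in width and sequence length. A shape-only illustration with made-up dimensions:

import torch

batch, seq_q, seq_kv = 2, 4096, 77
query_dim, cross_dim, inner_dim = 320, 768, 320

hidden_states = torch.randn(batch, seq_q, query_dim)
encoder_hidden_states = torch.randn(batch, seq_kv, cross_dim)

# to_k weight after the load-time TN->NN transpose: (cross_dim, inner_dim)
to_k_weight_nn = torch.randn(cross_dim, inner_dim)

key = torch.matmul(encoder_hidden_states, to_k_weight_nn)   # OK: (2, 77, 320)
print(key.shape)

try:
    key = torch.matmul(hidden_states, to_k_weight_nn)        # 320 vs 768: dimension mismatch
except RuntimeError as err:
    print(err)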
src/diffusers/models/model_loading_utils.py
@@ -307,47 +307,51 @@ def load_model_dict_into_meta(
             set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

-    # DCU OPT: TN->NN
-    for param_name, param in model.named_parameters():
-        if 'weight' in param_name and 'add_embedding.linear_1' in param_name:
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("rzc test error")
-        if 'weight' in param_name and 'add_embedding.linear_2' in param_name:
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("rzc test error")
-        if 'weight' in param_name and 'ff.net' in param_name:
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("lijian test error")
-        if 'weight' in param_name and 'time_emb_proj' in param_name:
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("lijian test error")
-        if 'weight' in param_name and 'attn' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("lijian test error")
-        if 'weight' in param_name and 'time_embedding' in param_name and ('linear_1' in param_name or 'linear_2' in param_name):
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("transpose weight to NN error")
-        if 'weight' in param_name and 'attentions' in param_name and ('proj_in' in param_name or 'proj_out' in param_name):
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("transpose weight to NN error")
-        if 'weight' in param_name and 'decoder.mid_block.attentions.0' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
-            if param.data.dim() == 2:
-                param.data = param.data.permute(1, 0).contiguous()
-            else:
-                raise ValueError("lijian test error")
+    # add sxx TN->NN
+    for param_name, param in model.named_parameters():
+        if 'weight' in param_name and 'add_embedding.linear_1' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("rzc test error")
+        if 'weight' in param_name and 'add_embedding.linear_2' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("rzc test error")
+        if 'weight' in param_name and 'ff.net' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("lijian test error")
+        if 'weight' in param_name and 'time_emb_proj' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("lijian test error")
+        if 'weight' in param_name and 'attn' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                #continue
+                raise ValueError("lijian test error")
+        if 'weight' in param_name and 'time_embedding' in param_name and ('linear_1' in param_name or 'linear_2' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("transpose weight to NN error")
+        if 'weight' in param_name and 'attentions' in param_name and ('proj_in' in param_name or 'proj_out' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("transpose weight to NN error")
+        if 'weight' in param_name and 'decoder.mid_block.attentions.0' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                #continue
+                raise ValueError("lijian test error")

     return offload_index, state_dict_index
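The model_loading_utils.py side of the change keeps a post-load pass that permutes selected 2-D weights in place so that the matmul-based processors above see NN-layout weights. A self-contained sketch of that pairing with a toy module (not the diffusers loader itself):

import torch
import torch.nn as nn

class ToyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(320, 640)

model = ToyAttn()
x = torch.randn(2, 77, 320)
ref = model.to_q(x)                              # reference output with the original (out, in) weight

# Load-time pass: transpose matching 2-D weights in place, as in load_model_dict_into_meta.
for name, param in model.named_parameters():
    if 'weight' in name and 'to_q' in name and param.data.dim() == 2:
        param.data = param.data.permute(1, 0).contiguous()

# Runtime path: plain matmul against the transposed weight, as in the attention processors.
out = torch.matmul(x, model.to_q.weight.data) + model.to_q.bias.data
print(torch.allclose(ref, out, atol=1e-6))       # True

After the in-place permute, the module's ordinary Linear forward is no longer valid, which is why every matching call site in attention_processor.py has to take the matmul path.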