Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d32e9391
Unverified
Commit
d32e9391
authored
Feb 16, 2023
by
fxmarty
Committed by
GitHub
Feb 16, 2023
Browse files
Replace torch.concat calls by torch.cat (#2378)
replace torch.concat by torch.cat
parent
aaaec064
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
src/diffusers/models/cross_attention.py
src/diffusers/models/cross_attention.py
+5
-5
No files found.
src/diffusers/models/cross_attention.py
View file @
d32e9391
...
@@ -275,7 +275,7 @@ class CrossAttention(nn.Module):
...
@@ -275,7 +275,7 @@ class CrossAttention(nn.Module):
# Instead, we can manually construct the padding tensor.
# Instead, we can manually construct the padding tensor.
padding_shape
=
(
attention_mask
.
shape
[
0
],
attention_mask
.
shape
[
1
],
target_length
)
padding_shape
=
(
attention_mask
.
shape
[
0
],
attention_mask
.
shape
[
1
],
target_length
)
padding
=
torch
.
zeros
(
padding_shape
,
dtype
=
attention_mask
.
dtype
,
device
=
attention_mask
.
device
)
padding
=
torch
.
zeros
(
padding_shape
,
dtype
=
attention_mask
.
dtype
,
device
=
attention_mask
.
device
)
attention_mask
=
torch
.
con
cat
([
attention_mask
,
padding
],
dim
=
2
)
attention_mask
=
torch
.
cat
([
attention_mask
,
padding
],
dim
=
2
)
else
:
else
:
attention_mask
=
F
.
pad
(
attention_mask
,
(
0
,
target_length
),
value
=
0.0
)
attention_mask
=
F
.
pad
(
attention_mask
,
(
0
,
target_length
),
value
=
0.0
)
...
@@ -409,8 +409,8 @@ class CrossAttnAddedKVProcessor:
...
@@ -409,8 +409,8 @@ class CrossAttnAddedKVProcessor:
encoder_hidden_states_key_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_key_proj
)
encoder_hidden_states_key_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_key_proj
)
encoder_hidden_states_value_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_value_proj
)
encoder_hidden_states_value_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_value_proj
)
key
=
torch
.
con
cat
([
encoder_hidden_states_key_proj
,
key
],
dim
=
1
)
key
=
torch
.
cat
([
encoder_hidden_states_key_proj
,
key
],
dim
=
1
)
value
=
torch
.
con
cat
([
encoder_hidden_states_value_proj
,
value
],
dim
=
1
)
value
=
torch
.
cat
([
encoder_hidden_states_value_proj
,
value
],
dim
=
1
)
attention_probs
=
attn
.
get_attention_scores
(
query
,
key
,
attention_mask
)
attention_probs
=
attn
.
get_attention_scores
(
query
,
key
,
attention_mask
)
hidden_states
=
torch
.
bmm
(
attention_probs
,
value
)
hidden_states
=
torch
.
bmm
(
attention_probs
,
value
)
...
@@ -588,8 +588,8 @@ class SlicedAttnAddedKVProcessor:
...
@@ -588,8 +588,8 @@ class SlicedAttnAddedKVProcessor:
encoder_hidden_states_key_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_key_proj
)
encoder_hidden_states_key_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_key_proj
)
encoder_hidden_states_value_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_value_proj
)
encoder_hidden_states_value_proj
=
attn
.
head_to_batch_dim
(
encoder_hidden_states_value_proj
)
key
=
torch
.
con
cat
([
encoder_hidden_states_key_proj
,
key
],
dim
=
1
)
key
=
torch
.
cat
([
encoder_hidden_states_key_proj
,
key
],
dim
=
1
)
value
=
torch
.
con
cat
([
encoder_hidden_states_value_proj
,
value
],
dim
=
1
)
value
=
torch
.
cat
([
encoder_hidden_states_value_proj
,
value
],
dim
=
1
)
batch_size_attention
=
query
.
shape
[
0
]
batch_size_attention
=
query
.
shape
[
0
]
hidden_states
=
torch
.
zeros
(
hidden_states
=
torch
.
zeros
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment