chenpangpang / ComfyUI · Commit e21d9ad4
"...git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "afa2399f79e84919645eb69cd8e1ef1d9f1d6bd1"
Commit e21d9ad4, authored Jun 15, 2023 by comfyanonymous

Initialize transformer unet block weights in the right dtype at the start.
Parent: 6253ec4a

Showing 2 changed files with 44 additions and 44 deletions (+44 −44)
comfy/ldm/modules/attention.py                         +41 −41
comfy/ldm/modules/diffusionmodules/openaimodel.py       +3 −3
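For context, the point of the change: every transformer building block now takes a dtype argument and passes it to the layers it constructs, so parameters are allocated in the UNet's working precision at creation time instead of being built in float32 and cast afterwards. A minimal sketch of the idea with plain torch.nn (comfy.ops.Linear is assumed to accept this keyword and hand it to nn.Linear, which does support it):

    import torch
    import torch.nn as nn

    dtype = torch.float16

    # Without the dtype argument: the layer is initialized in float32, and a separate
    # .half() pass later has to walk the module and convert every parameter.
    proj = nn.Linear(320, 1280)
    print(proj.weight.dtype)   # torch.float32
    proj = proj.half()

    # With the dtype argument: the weights are born in the target precision,
    # so there is no second allocation and no conversion pass.
    proj = nn.Linear(320, 1280, dtype=dtype)
    print(proj.weight.dtype)   # torch.float16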
comfy/ldm/modules/attention.py (view file @ e21d9ad4)
@@ -61,19 +61,19 @@ class GEGLU(nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
             nn.GELU()
         ) if not glu else GEGLU(dim, inner_dim)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
         )

     def forward(self, x):
@@ -147,7 +147,7 @@ class SpatialSelfAttention(nn.Module):
 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -155,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )
@@ -244,7 +244,7 @@ class CrossAttentionBirchSan(nn.Module):
 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -252,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )
@@ -342,7 +342,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)

 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -350,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )
@@ -398,7 +398,7 @@ class CrossAttention(nn.Module):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")
@@ -408,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -449,7 +449,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)

 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -457,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -507,17 +507,17 @@ else:
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False):
+                 disable_self_attn=False, dtype=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-        self.norm3 = nn.LayerNorm(dim)
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
         self.checkpoint = checkpoint

     def forward(self, x, context=None, transformer_options={}):
@@ -588,7 +588,7 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True):
+                 use_checkpoint=True, dtype=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim]
@@ -600,22 +600,22 @@ class SpatialTransformer(nn.Module):
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0)
+                                     padding=0, dtype=dtype)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim, in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0)
+                                      padding=0, dtype=dtype)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
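Usage note (not part of the commit): with the dtype keyword in place, the attention classes above can be built directly in the target precision. A hypothetical sketch, assuming ComfyUI is importable and that the import path mirrors the file being patched:

    import torch
    from comfy.ldm.modules.attention import CrossAttention  # import path assumed from the patched file

    # The q/k/v projections now hold fp16 weights from construction,
    # so no later .half() cast is needed for these layers.
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40,
                          dtype=torch.float16)
    print(attn.to_q.weight.dtype)   # torch.float16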
comfy/ldm/modules/diffusionmodules/openaimodel.py (view file @ e21d9ad4)
@@ -631,7 +631,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                                 ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint
+                                use_checkpoint=use_checkpoint, dtype=self.dtype
                             )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -688,7 +688,7 @@ class UNetModel(nn.Module):
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, dtype=self.dtype
                         ),
             ResBlock(
                 ch,
@@ -742,7 +742,7 @@ class UNetModel(nn.Module):
                     ) if not use_spatial_transformer else SpatialTransformer(
                         ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                         disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                        use_checkpoint=use_checkpoint
+                        use_checkpoint=use_checkpoint, dtype=self.dtype
                     )
                 )
             if level and i == self.num_res_blocks[level]:
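Usage note (not part of the commit): in this file the only change is that UNetModel forwards its working precision (self.dtype) into each SpatialTransformer it builds, at all three call sites. A hypothetical sketch of the effect, assuming ComfyUI is importable; the sizes are illustrative:

    import torch
    from comfy.ldm.modules.attention import SpatialTransformer  # SpatialTransformer lives in the patched attention.py

    # What UNetModel now does at each call site: pass dtype=self.dtype.
    st = SpatialTransformer(320, 8, 40, depth=1, context_dim=768,
                            use_linear=False, use_checkpoint=False,
                            dtype=torch.float16)
    print(st.proj_in.weight.dtype)                            # torch.float16
    print(st.transformer_blocks[0].attn1.to_q.weight.dtype)   # torch.float16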