chenpangpang / ComfyUI · Commit e21d9ad4

Authored Jun 15, 2023 by comfyanonymous

Initialize transformer unet block weights in right dtype at the start.

Parent: 6253ec4a
Showing 2 changed files with 44 additions and 44 deletions:

  comfy/ldm/modules/attention.py                         +41  -41
  comfy/ldm/modules/diffusionmodules/openaimodel.py       +3   -3
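The whole commit is one pattern applied repeatedly: every transformer block constructor in attention.py gains a dtype=None parameter and forwards it to the Linear, LayerNorm, and Conv2d layers it creates, and UNetModel in openaimodel.py passes dtype=self.dtype when it builds each SpatialTransformer. The weights are therefore allocated in the UNet's working dtype at construction time instead of being created in float32 and cast afterwards. A minimal sketch of the idea, using plain torch.nn in place of the comfy.ops wrappers (TinyFeedForward is an illustrative name, not part of ComfyUI):

import torch
import torch.nn as nn

class TinyFeedForward(nn.Module):
    def __init__(self, dim, mult=4, dtype=None):
        super().__init__()
        # dtype is forwarded at construction time, so the weights are allocated
        # directly in the target dtype (no float32 allocation followed by a cast).
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult, dtype=dtype),
            nn.GELU(),
            nn.Linear(dim * mult, dim, dtype=dtype),
        )

    def forward(self, x):
        return self.net(x)

block = TinyFeedForward(320, dtype=torch.float16)
print(next(block.parameters()).dtype)  # torch.float16 straight away, no .half() pass needed

Constructing directly in the target dtype means the float32 copies are never materialized, which keeps peak memory during model creation lower.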
comfy/ldm/modules/attention.py

@@ -61,19 +61,19 @@ class GEGLU(nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
             nn.GELU()
         ) if not glu else GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
         )
 
     def forward(self, x):

@@ -147,7 +147,7 @@ class SpatialSelfAttention(nn.Module):
 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -155,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
 
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )

@@ -244,7 +244,7 @@ class CrossAttentionBirchSan(nn.Module):
 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -252,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
 
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )

@@ -342,7 +342,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -350,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
 
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )

@@ -398,7 +398,7 @@ class CrossAttention(nn.Module):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")

@@ -408,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
 
-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):

@@ -449,7 +449,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)
 
 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -457,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
 
-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):

@@ -507,17 +507,17 @@ else:
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False):
+                 disable_self_attn=False, dtype=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim)
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm2 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm3 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
         self.checkpoint = checkpoint
 
     def forward(self, x, context=None, transformer_options={}):

@@ -588,7 +588,7 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True):
+                 use_checkpoint=True, dtype=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim]

@@ -600,22 +600,22 @@ class SpatialTransformer(nn.Module):
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0)
+                                     padding=0, dtype=dtype)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
 
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim, in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0)
+                                      padding=0, dtype=dtype)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
         self.use_linear = use_linear
 
     def forward(self, x, context=None, transformer_options={}):
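Every constructor above follows the same recipe, so one sanity check covers all of them: build a block with an explicit dtype and confirm each parameter was created in it. A sketch with stock torch.nn modules standing in for the comfy.ops wrappers (the helper name is illustrative):

import torch
import torch.nn as nn

def all_params_in_dtype(module: nn.Module, dtype: torch.dtype) -> bool:
    # True only if every weight and bias of the module was created in `dtype`
    return all(p.dtype == dtype for p in module.parameters())

# Stand-in for a dtype-threaded attention block: norm + projection built in fp16
attn_like = nn.Sequential(
    nn.LayerNorm(320, dtype=torch.float16),
    nn.Linear(320, 320, bias=False, dtype=torch.float16),
)
assert all_params_in_dtype(attn_like, torch.float16)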
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -631,7 +631,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, dtype=self.dtype
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))

@@ -688,7 +688,7 @@ class UNetModel(nn.Module):
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, dtype=self.dtype
                         ),
             ResBlock(
                 ch,

@@ -742,7 +742,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, dtype=self.dtype
                         )
                     )
                 if level and i == self.num_res_blocks[level]:
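In openaimodel.py the only change is the last argument of each SpatialTransformer(...) call: the UNet hands its own self.dtype down, so the transformer weights come out matching the rest of the model. The forwarding pattern looks roughly like this (TinySpatialTransformer and TinyUNet are illustrative stand-ins, not ComfyUI classes):

import torch
import torch.nn as nn

class TinySpatialTransformer(nn.Module):
    def __init__(self, channels, dtype=None):
        super().__init__()
        # Submodules receive the dtype they should be created in
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, dtype=dtype)
        self.norm = nn.LayerNorm(channels, dtype=dtype)

class TinyUNet(nn.Module):
    def __init__(self, channels, dtype=torch.float16):
        super().__init__()
        self.dtype = dtype
        # Same idea as `dtype=self.dtype` in the hunks above: each transformer
        # is built in the UNet's dtype instead of being cast after construction
        self.blocks = nn.ModuleList(
            [TinySpatialTransformer(channels, dtype=self.dtype) for _ in range(2)]
        )

unet = TinyUNet(64)
print({p.dtype for p in unet.parameters()})  # {torch.float16}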