chenpangpang / ComfyUI / Commits / 4b957a00

Commit 4b957a00, authored Jul 29, 2023 by comfyanonymous
Initialize the unet directly on the target device.
Parent: ad5866b0

Showing 6 changed files with 110 additions and 103 deletions.
comfy/ldm/modules/attention.py (+47, -47)
comfy/ldm/modules/diffusionmodules/openaimodel.py (+44, -36)
comfy/model_base.py (+10, -10)
comfy/sd.py (+1, -2)
comfy/supported_models.py (+4, -4)
comfy/supported_models_base.py (+4, -4)
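The whole commit is one mechanical pattern: every constructor on the UNet build path gains a `device=None` keyword and forwards it to the PyTorch layers it creates, so the weights are allocated on the target device from the start instead of being created on the default device and then moved with `.to()`. A minimal sketch of that pattern, using a toy module rather than the real ComfyUI classes:

import torch
import torch.nn as nn

# Toy illustration only (not ComfyUI code): forward `device` into every layer
# constructor so parameters are created on the target device directly.
class TinyBlock(nn.Module):
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        # nn.Linear / nn.LayerNorm accept dtype/device factory kwargs (PyTorch >= 1.9).
        self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
        self.norm = nn.LayerNorm(dim, dtype=dtype, device=device)

    def forward(self, x):
        return self.norm(self.proj(x))

target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Old behaviour: allocate on the default device, then copy everything over.
block_old = TinyBlock(64).to(target)

# New behaviour: allocate on the target device in the first place,
# skipping the intermediate allocation and the copy.
block_new = TinyBlock(64, device=target)

The diffs below apply the same idea file by file: each one only adds the extra keyword and passes it along.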
comfy/ldm/modules/attention.py

@@ -52,9 +52,9 @@ def init_(tensor):
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)

@@ -62,19 +62,19 @@ class GEGLU(nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
         )

     def forward(self, x):

@@ -90,8 +90,8 @@ def zero_module(module):
     return module

-def Normalize(in_channels, dtype=None):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
+def Normalize(in_channels, dtype=None, device=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

 class SpatialSelfAttention(nn.Module):

@@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):
 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):
 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)

 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -399,7 +399,7 @@ class CrossAttention(nn.Module):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")

@@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)

 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -508,17 +508,17 @@ else:
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None):
+                 disable_self_attn=False, dtype=None, device=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head

@@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None):
+                 use_checkpoint=True, dtype=None, device=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype)
+        self.norm = Normalize(in_channels, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0, dtype=dtype)
+                                     padding=0, dtype=dtype, device=device)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim, in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0, dtype=dtype)
+                                      padding=0, dtype=dtype, device=device)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
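Every change above is the same threading move: a constructor accepts `device=None` and hands it, unchanged, to whatever it builds (comfy.ops.Linear, nn.LayerNorm, nn.GroupNorm, nn.Conv2d), so a value set once on SpatialTransformer reaches the innermost GEGLU projection. A hedged toy sketch of that forwarding chain (Gate and Block are illustrative stand-ins, not the real classes):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Gate(nn.Module):          # stands in for GEGLU
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        self.proj = nn.Linear(dim, dim * 2, dtype=dtype, device=device)

    def forward(self, x):
        a, g = self.proj(x).chunk(2, dim=-1)
        return a * F.gelu(g)

class Block(nn.Module):         # stands in for BasicTransformerBlock
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        # The kwarg is forwarded verbatim; only the leaf layers allocate parameters.
        self.norm = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.ff = Gate(dim, dtype=dtype, device=device)

    def forward(self, x):
        return x + self.ff(self.norm(x))

dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
blk = Block(128, device=dev)
print(next(blk.parameters()).device)   # parameters start out on `dev`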
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -111,14 +111,14 @@ class Upsample(nn.Module):
                  upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels

@@ -160,7 +160,7 @@ class Downsample(nn.Module):
                  downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels

@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
             )
         else:
             assert self.channels == self.out_channels

@@ -208,7 +208,8 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
-        dtype=None
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         self.channels = channels

@@ -220,19 +221,19 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
         )

         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
         elif down:
-            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
         else:
             self.h_upd = self.x_upd = nn.Identity()

@@ -240,15 +241,15 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
             ),
         )

@@ -256,10 +257,10 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype
+                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
             )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

     def forward(self, x, emb):
         """

@@ -503,6 +504,7 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        device=None,
     ):
         super().__init__()
         if use_spatial_transformer:

@@ -564,9 +566,9 @@ class UNetModel(nn.Module):
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype),
+            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )

         if self.num_classes is not None:

@@ -579,9 +581,9 @@ class UNetModel(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
+                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:

@@ -590,7 +592,7 @@ class UNetModel(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )

@@ -609,7 +611,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = mult * model_channels

@@ -638,7 +641,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))

@@ -657,11 +660,12 @@ class UNetModel(nn.Module):
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         down=True,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                     if resblock_updown
                     else Downsample(
-                        ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
+                        ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
                     )
                 )
             )

@@ -686,7 +690,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
             AttentionBlock(
                 ch,

@@ -697,7 +702,7 @@ class UNetModel(nn.Module):
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         ),
             ResBlock(
                 ch,

@@ -706,7 +711,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
         )
         self._feature_size += ch

@@ -724,7 +730,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = model_channels * mult

@@ -753,7 +760,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 if level and i == self.num_res_blocks[level]:

@@ -768,24 +775,25 @@ class UNetModel(nn.Module):
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         up=True,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                     if resblock_updown
-                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
+                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
                 )
                 ds //= 2
            self.output_blocks.append(TimestepEmbedSequential(*layers))
            self._feature_size += ch

        self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
-            conv_nd(dims, model_channels, n_embed, 1),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+            conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )
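Throughout openaimodel.py the new keyword defaults to `device=None`, which is what keeps existing callers working: PyTorch factory kwargs treat `None` as "use the current default device", so code that never passes `device` gets exactly the previous behaviour. A quick check of that assumption (plain PyTorch, not ComfyUI code):

import torch
import torch.nn as nn

layer_default = nn.Linear(8, 8, device=None)
print(layer_default.weight.device)       # cpu (unless the default device was changed)

if torch.cuda.is_available():
    layer_gpu = nn.Linear(8, 8, device=torch.device("cuda", 0))
    print(layer_gpu.weight.device)       # cuda:0 -- allocated there directly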
comfy/model_base.py

@@ -12,14 +12,14 @@ class ModelType(Enum):
     V_PREDICTION = 2


 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config)
+        self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:

@@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):
 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):

@@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):

@@ -174,8 +174,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
comfy/sd.py

@@ -1169,8 +1169,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.")
-    model = model.to(offload_device)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
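This is the call-site change the rest of the commit enables: `load_checkpoint_guess_config` now hands the offload device to `get_model` rather than building the model first and moving it, so the parameters are created on that device instead of being allocated elsewhere and copied when the offload device differs. A hedged toy mirroring the before/after (MockConfig and MockModel are illustrative stand-ins, not ComfyUI classes):

import torch
import torch.nn as nn

class MockModel(nn.Module):
    def __init__(self, device=None):
        super().__init__()
        self.diffusion_model = nn.Linear(4, 4, device=device)

class MockConfig:
    def get_model(self, state_dict, prefix="", device=None):
        return MockModel(device=device)

offload_device = torch.device("cpu")   # stands in for model_management.unet_offload_device()
sd_state = {}                          # stands in for the loaded checkpoint state dict

# Old path: construct on the default device, then copy to the offload device.
model = MockConfig().get_model(sd_state, "model.diffusion_model.")
model = model.to(offload_device)

# New path: construct directly on the offload device; no second allocation, no copy.
model = MockConfig().get_model(sd_state, "model.diffusion_model.", device=offload_device)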
comfy/supported_models.py

@@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXLRefiner(self)
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXLRefiner(self, device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}

@@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
         else:
             return model_base.ModelType.EPS

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
comfy/supported_models_base.py

@@ -53,13 +53,13 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

-    def get_model(self, state_dict, prefix=""):
+    def get_model(self, state_dict, prefix="", device=None):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
         else:
-            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)

     def process_clip_state_dict(self, state_dict):
         return state_dict