chenpangpang / ComfyUI / Commits
Commit 4b957a00, authored Jul 29, 2023 by comfyanonymous

Initialize the unet directly on the target device.

Parent: ad5866b0
Showing 6 changed files with 110 additions and 103 deletions (+110, -103):
comfy/ldm/modules/attention.py (+47, -47)
comfy/ldm/modules/diffusionmodules/openaimodel.py (+44, -36)
comfy/model_base.py (+10, -10)
comfy/sd.py (+1, -2)
comfy/supported_models.py (+4, -4)
comfy/supported_models_base.py (+4, -4)
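The change is mechanical but has one purpose: when a device argument is threaded into every layer constructor, PyTorch allocates the parameters directly on that device instead of building the whole UNet on the CPU and copying it afterwards. A minimal sketch of the difference (illustrative only, plain nn.Linear rather than comfy.ops.Linear):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Old pattern: the weight is first allocated on the CPU, then copied over.
layer_then_move = nn.Linear(320, 1280, dtype=torch.float16).to(device)

# New pattern (what this commit threads through every constructor):
# the weight is allocated on the target device in the first place.
layer_direct = nn.Linear(320, 1280, dtype=torch.float16, device=device)

assert layer_direct.weight.device.type == device.type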
comfy/ldm/modules/attention.py (view file @ 4b957a00)

@@ -52,9 +52,9 @@ def init_(tensor):

 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)

@@ -62,19 +62,19 @@ class GEGLU(nn.Module):

 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
         )

     def forward(self, x):

@@ -90,8 +90,8 @@ def zero_module(module):
     return module

-def Normalize(in_channels, dtype=None):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
+def Normalize(in_channels, dtype=None, device=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

 class SpatialSelfAttention(nn.Module):

@@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):

 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):

 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)

 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -399,7 +399,7 @@ class CrossAttention(nn.Module):

 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")

@@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)

 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -508,17 +508,17 @@ else:

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None):
+                 disable_self_attn=False, dtype=None, device=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head

@@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None):
+                 use_checkpoint=True, dtype=None, device=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype)
+        self.norm = Normalize(in_channels, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0, dtype=dtype)
+                                     padding=0, dtype=dtype, device=device)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim,
                                       in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0, dtype=dtype)
+                                      padding=0, dtype=dtype, device=device)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
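Every attention and feed-forward constructor above follows the same recipe: accept device=None and forward it together with dtype to each sublayer. A small self-contained sketch of that recipe, with a toy gated projection standing in for GEGLU and plain nn.Linear standing in for comfy.ops.Linear:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGEGLU(nn.Module):
    # Illustrative stand-in for GEGLU: the constructor takes dtype/device and
    # forwards both to its only sublayer, so the projection weight is created
    # on the requested device rather than moved there later.
    def __init__(self, dim_in, dim_out, dtype=None, device=None):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

dev = "cuda" if torch.cuda.is_available() else "cpu"
block = ToyGEGLU(64, 128, dtype=torch.float32, device=dev)
print(block.proj.weight.device)  # already on `dev`, no .to() needed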
comfy/ldm/modules/diffusionmodules/openaimodel.py (view file @ 4b957a00)

@@ -111,14 +111,14 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels

@@ -160,7 +160,7 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels

@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
             )
         else:
             assert self.channels == self.out_channels

@@ -208,7 +208,8 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
-        dtype=None
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         self.channels = channels

@@ -220,19 +221,19 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
         )

         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
         elif down:
-            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
         else:
             self.h_upd = self.x_upd = nn.Identity()

@@ -240,15 +241,15 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
             ),
         )

@@ -256,10 +257,10 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype
+                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
             )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

     def forward(self, x, emb):
         """

@@ -503,6 +504,7 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        device=None,
     ):
         super().__init__()
         if use_spatial_transformer:

@@ -564,9 +566,9 @@ class UNetModel(nn.Module):
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype),
+            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )

         if self.num_classes is not None:

@@ -579,9 +581,9 @@ class UNetModel(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
+                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:

@@ -590,7 +592,7 @@ class UNetModel(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )

@@ -609,7 +611,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = mult * model_channels

@@ -638,7 +641,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))

@@ -657,11 +660,12 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
                         )
                     )
                 )

@@ -686,7 +690,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
             AttentionBlock(
                 ch,

@@ -697,7 +702,7 @@ class UNetModel(nn.Module):
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         ),
             ResBlock(
                 ch,

@@ -706,7 +711,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
         )
         self._feature_size += ch

@@ -724,7 +730,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = model_channels * mult

@@ -753,7 +760,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 if level and i == self.num_res_blocks[level]:

@@ -768,24 +775,25 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch

         self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
-            conv_nd(dims, model_channels, n_embed, 1),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+            conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
         )
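In openaimodel.py the same device value has to travel through several levels of nesting: UNetModel builds ResBlocks, which in turn build Upsample/Downsample and their convolutions. A hedged sketch of that threading pattern; the build_time_embed helper and the channel sizes below are invented for illustration:

import torch
import torch.nn as nn

def build_time_embed(model_channels, time_embed_dim, dtype=None, device=None):
    # Mirrors the shape of the time_embed construction in UNetModel.__init__:
    # every nested constructor receives the same dtype/device pair, so the
    # whole sub-tree is materialized on the target device in one pass.
    return nn.Sequential(
        nn.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
    )

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
time_embed = build_time_embed(320, 1280, dtype=torch.float32, device=dev)
assert all(p.device.type == dev.type for p in time_embed.parameters())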
comfy/model_base.py (view file @ 4b957a00)

@@ -12,14 +12,14 @@ class ModelType(Enum):
     V_PREDICTION = 2

 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config)
+        self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:

@@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):

 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):

@@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):

@@ -174,8 +174,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
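model_base.py only forwards the new keyword: each subclass grows device=None and hands it to super().__init__, which in turn hands it to UNetModel(**unet_config, device=device). A minimal sketch of that forwarding chain; the class and field names below are invented stand-ins, not ComfyUI code:

import torch.nn as nn

class ToyBaseModel(nn.Module):
    def __init__(self, width, device=None):
        super().__init__()
        # Corresponds to UNetModel(**unet_config, device=device).
        self.diffusion_model = nn.Linear(width, width, device=device)

class ToyInpaintModel(ToyBaseModel):
    def __init__(self, width, device=None):
        # The subclass adds nothing device-specific itself; it just forwards.
        super().__init__(width, device=device)
        self.concat_keys = ("mask", "masked_image")

model = ToyInpaintModel(8, device="cpu")
print(model.diffusion_model.weight.device)  # cpu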
comfy/sd.py (view file @ 4b957a00)

@@ -1169,8 +1169,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.")
-    model = model.to(offload_device)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
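The sd.py hunk is where the behavior change shows: instead of building the model and then calling .to(offload_device), which briefly holds two copies of every weight, the offload device is passed into get_model so the UNet is created there directly. A rough before/after sketch, with fake_get_model standing in for model_config.get_model:

import torch
import torch.nn as nn

def fake_get_model(device=None):
    # Stand-in for model_config.get_model(sd, "model.diffusion_model.", device=...).
    return nn.Linear(4, 4, device=device)

offload_device = torch.device("cpu")  # model_management.unet_offload_device() in ComfyUI

# Before: construct, then move (transient double allocation of the weights).
model_before = fake_get_model()
model_before = model_before.to(offload_device)

# After: construct on the offload device in one step.
model_after = fake_get_model(device=offload_device)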
comfy/supported_models.py (view file @ 4b957a00)

@@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXLRefiner(self)
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXLRefiner(self, device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}

@@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
         else:
             return model_base.ModelType.EPS

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
comfy/supported_models_base.py (view file @ 4b957a00)

@@ -53,13 +53,13 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

-    def get_model(self, state_dict, prefix=""):
+    def get_model(self, state_dict, prefix="", device=None):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
         else:
-            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)

     def process_clip_state_dict(self, state_dict):
         return state_dict
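One detail worth noting in the new get_model signatures: device=None keeps older callers working, because PyTorch treats device=None the same as not passing the argument at all, so existing call sites still construct on the default device. A tiny check of that behavior (illustrative, plain nn.Linear):

import torch.nn as nn

# device=None is the backward-compatible default: it behaves exactly like
# omitting the argument, so callers that predate this commit are unaffected.
a = nn.Linear(4, 4)
b = nn.Linear(4, 4, device=None)
assert a.weight.device == b.weight.device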