chenpangpang / ComfyUI / Commits / 4b957a00

Commit 4b957a00, authored Jul 29, 2023 by comfyanonymous
Parent: ad5866b0

    Initialize the unet directly on the target device.

Showing 6 changed files with 110 additions and 103 deletions (+110 -103):

    comfy/ldm/modules/attention.py                      +47 -47
    comfy/ldm/modules/diffusionmodules/openaimodel.py   +44 -36
    comfy/model_base.py                                 +10 -10
    comfy/sd.py                                          +1  -2
    comfy/supported_models.py                            +4  -4
    comfy/supported_models_base.py                       +4  -4
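What the commit does, in a nutshell: every module in the UNet stack gains a device argument that is forwarded to the underlying torch constructors, so parameters are allocated on the target device when the module is built instead of being created on the CPU and copied over afterwards. A minimal sketch of the pattern (a toy class, not ComfyUI's actual code):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Toy module illustrating the factory-kwarg pattern applied throughout this commit.
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        # nn.Linear / nn.LayerNorm accept dtype/device factory kwargs, so the
        # weights are created directly on the device they will be used on.
        self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
        self.norm = nn.LayerNorm(dim, dtype=dtype, device=device)

    def forward(self, x):
        return self.proj(self.norm(x))

# Before: block = TinyBlock(320); block = block.to("cuda")   # CPU init, then a second allocation + copy
# After:  block = TinyBlock(320, device="cuda")              # built in place on the target device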
comfy/ldm/modules/attention.py

@@ -52,9 +52,9 @@ def init_(tensor):
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -62,19 +62,19 @@ class GEGLU(nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
         )

     def forward(self, x):
@@ -90,8 +90,8 @@ def zero_module(module):
     return module

-def Normalize(in_channels, dtype=None):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
+def Normalize(in_channels, dtype=None, device=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

 class SpatialSelfAttention(nn.Module):
@@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):
 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):
 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)

 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -399,7 +399,7 @@ class CrossAttention(nn.Module):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")
@@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)

 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -508,17 +508,17 @@ else:
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None):
+                 disable_self_attn=False, dtype=None, device=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
@@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None):
+                 use_checkpoint=True, dtype=None, device=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype)
+        self.norm = Normalize(in_channels, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0, dtype=dtype)
+                                     padding=0, dtype=dtype, device=device)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim, in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0, dtype=dtype)
+                                      padding=0, dtype=dtype, device=device)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
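With the new parameter, the attention building blocks patched above can be constructed directly on the target device. A hypothetical usage sketch (the dimensions are arbitrary example values, not taken from the diff):

import torch
from comfy.ldm.modules.attention import FeedForward, CrossAttention

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both constructors now forward device= down to their comfy.ops.Linear layers.
ff = FeedForward(320, glu=True, dtype=torch.float16, device=device)
attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40,
                      dtype=torch.float16, device=device)
# Parameters live on `device` from the start; no second allocation via .to() is needed.
print(next(ff.parameters()).device, next(attn.parameters()).device)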
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -111,14 +111,14 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
             )
         else:
             assert self.channels == self.out_channels
@@ -208,7 +208,8 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
-        dtype=None
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         self.channels = channels
@@ -220,19 +221,19 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
         )

         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
         elif down:
-            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
         else:
             self.h_upd = self.x_upd = nn.Identity()
@@ -240,15 +241,15 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
             ),
         )
@@ -256,10 +257,10 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype
+                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
             )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

     def forward(self, x, emb):
         """
@@ -503,6 +504,7 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        device=None,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -564,9 +566,9 @@ class UNetModel(nn.Module):
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype),
+            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )

         if self.num_classes is not None:
@@ -579,9 +581,9 @@ class UNetModel(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
+                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:
@@ -590,7 +592,7 @@ class UNetModel(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )
@@ -609,7 +611,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = mult * model_channels
@@ -638,7 +641,7 @@ class UNetModel(nn.Module):
                             ) if not use_spatial_transformer else SpatialTransformer(
                                 ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, dtype=self.dtype
+                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                             )
                         )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -657,11 +660,12 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
                         )
                     )
                 )
@@ -686,7 +690,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
             AttentionBlock(
                 ch,
@@ -697,7 +702,7 @@ class UNetModel(nn.Module):
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         ),
             ResBlock(
                 ch,
@@ -706,7 +711,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
         )
         self._feature_size += ch
@@ -724,7 +730,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = model_channels * mult
@@ -753,7 +760,7 @@ class UNetModel(nn.Module):
                             ) if not use_spatial_transformer else SpatialTransformer(
                                 ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, dtype=self.dtype
+                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                             )
                         )
                     if level and i == self.num_res_blocks[level]:
@@ -768,24 +775,25 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch

         self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
-            conv_nd(dims, model_channels, n_embed, 1),
+            conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )
comfy/model_base.py

@@ -12,14 +12,14 @@ class ModelType(Enum):
     V_PREDICTION = 2

 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config)
+        self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
@@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):
 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):
@@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
@@ -174,8 +174,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
comfy/sd.py

@@ -1169,8 +1169,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.")
-    model = model.to(offload_device)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
comfy/supported_models.py

@@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXLRefiner(self)
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXLRefiner(self, device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
@@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
         else:
             return model_base.ModelType.EPS

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
comfy/supported_models_base.py

@@ -53,13 +53,13 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

-    def get_model(self, state_dict, prefix=""):
+    def get_model(self, state_dict, prefix="", device=None):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
         else:
-            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)

     def process_clip_state_dict(self, state_dict):
         return state_dict
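Taken together, the six files thread the device argument from the checkpoint loader all the way down to the individual layers, so the diffusion model is built directly on the UNet offload device. A simplified sketch of the resulting call chain (names taken from the diff, bodies trimmed, comments are mine):

# comfy/sd.py
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
# comfy/supported_models_base.py: BASE.get_model(...)
#     -> model_base.BaseModel(self, model_type=..., device=device)
# comfy/model_base.py: BaseModel.__init__(...)
#     -> UNetModel(**unet_config, device=device)
# comfy/ldm/modules/diffusionmodules/openaimodel.py: UNetModel.__init__(...)
#     -> conv_nd(..., device=device), ResBlock(..., device=device), SpatialTransformer(..., device=device)
model.load_model_weights(sd, "model.diffusion_model.")  # weights are loaded into tensors already on offload_device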