Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
85fde89d
Commit
85fde89d
authored
Aug 22, 2023
by
comfyanonymous
Browse files
T2I adapter SDXL.
parent
f2a7cc91
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
10 deletions
+48
-10
comfy/sd.py
comfy/sd.py
+17
-4
comfy/t2i_adapter/adapter.py
comfy/t2i_adapter/adapter.py
+31
-6
No files found.
comfy/sd.py
View file @
85fde89d
...
@@ -1128,7 +1128,11 @@ class T2IAdapter(ControlBase):
...
@@ -1128,7 +1128,11 @@ class T2IAdapter(ControlBase):
self
.
t2i_model
.
cpu
()
self
.
t2i_model
.
cpu
()
control_input
=
list
(
map
(
lambda
a
:
None
if
a
is
None
else
a
.
clone
(),
self
.
control_input
))
control_input
=
list
(
map
(
lambda
a
:
None
if
a
is
None
else
a
.
clone
(),
self
.
control_input
))
return
self
.
control_merge
(
control_input
,
None
,
control_prev
,
x_noisy
.
dtype
)
mid
=
None
if
self
.
t2i_model
.
xl
==
True
:
mid
=
control_input
[
-
1
:]
control_input
=
control_input
[:
-
1
]
return
self
.
control_merge
(
control_input
,
mid
,
control_prev
,
x_noisy
.
dtype
)
def
copy
(
self
):
def
copy
(
self
):
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
...
@@ -1151,11 +1155,20 @@ def load_t2i_adapter(t2i_data):
...
@@ -1151,11 +1155,20 @@ def load_t2i_adapter(t2i_data):
down_opts
=
list
(
filter
(
lambda
a
:
a
.
endswith
(
"down_opt.op.weight"
),
keys
))
down_opts
=
list
(
filter
(
lambda
a
:
a
.
endswith
(
"down_opt.op.weight"
),
keys
))
if
len
(
down_opts
)
>
0
:
if
len
(
down_opts
)
>
0
:
use_conv
=
True
use_conv
=
True
model_ad
=
adapter
.
Adapter
(
cin
=
cin
,
channels
=
[
channel
,
channel
*
2
,
channel
*
4
,
channel
*
4
][:
4
],
nums_rb
=
2
,
ksize
=
ksize
,
sk
=
True
,
use_conv
=
use_conv
)
xl
=
False
if
cin
==
256
:
xl
=
True
model_ad
=
adapter
.
Adapter
(
cin
=
cin
,
channels
=
[
channel
,
channel
*
2
,
channel
*
4
,
channel
*
4
][:
4
],
nums_rb
=
2
,
ksize
=
ksize
,
sk
=
True
,
use_conv
=
use_conv
,
xl
=
xl
)
else
:
else
:
return
None
return
None
model_ad
.
load_state_dict
(
t2i_data
)
missing
,
unexpected
=
model_ad
.
load_state_dict
(
t2i_data
)
return
T2IAdapter
(
model_ad
,
cin
//
64
)
if
len
(
missing
)
>
0
:
print
(
"t2i missing"
,
missing
)
if
len
(
unexpected
)
>
0
:
print
(
"t2i unexpected"
,
unexpected
)
return
T2IAdapter
(
model_ad
,
model_ad
.
input_channels
)
class
StyleModel
:
class
StyleModel
:
...
...
comfy/t2i_adapter/adapter.py
View file @
85fde89d
...
@@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
...
@@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
class
Adapter
(
nn
.
Module
):
class
Adapter
(
nn
.
Module
):
def
__init__
(
self
,
channels
=
[
320
,
640
,
1280
,
1280
],
nums_rb
=
3
,
cin
=
64
,
ksize
=
3
,
sk
=
False
,
use_conv
=
True
):
def
__init__
(
self
,
channels
=
[
320
,
640
,
1280
,
1280
],
nums_rb
=
3
,
cin
=
64
,
ksize
=
3
,
sk
=
False
,
use_conv
=
True
,
xl
=
True
):
super
(
Adapter
,
self
).
__init__
()
super
(
Adapter
,
self
).
__init__
()
self
.
unshuffle
=
nn
.
PixelUnshuffle
(
8
)
unshuffle
=
8
resblock_no_downsample
=
[]
resblock_downsample
=
[
3
,
2
,
1
]
self
.
xl
=
xl
if
self
.
xl
:
unshuffle
=
16
resblock_no_downsample
=
[
1
]
resblock_downsample
=
[
2
]
self
.
input_channels
=
cin
//
(
unshuffle
*
unshuffle
)
self
.
unshuffle
=
nn
.
PixelUnshuffle
(
unshuffle
)
self
.
channels
=
channels
self
.
channels
=
channels
self
.
nums_rb
=
nums_rb
self
.
nums_rb
=
nums_rb
self
.
body
=
[]
self
.
body
=
[]
for
i
in
range
(
len
(
channels
)):
for
i
in
range
(
len
(
channels
)):
for
j
in
range
(
nums_rb
):
for
j
in
range
(
nums_rb
):
if
(
i
!=
0
)
and
(
j
==
0
):
if
(
i
in
resblock_downsample
)
and
(
j
==
0
):
self
.
body
.
append
(
self
.
body
.
append
(
ResnetBlock
(
channels
[
i
-
1
],
channels
[
i
],
down
=
True
,
ksize
=
ksize
,
sk
=
sk
,
use_conv
=
use_conv
))
ResnetBlock
(
channels
[
i
-
1
],
channels
[
i
],
down
=
True
,
ksize
=
ksize
,
sk
=
sk
,
use_conv
=
use_conv
))
elif
(
i
in
resblock_no_downsample
)
and
(
j
==
0
):
self
.
body
.
append
(
ResnetBlock
(
channels
[
i
-
1
],
channels
[
i
],
down
=
False
,
ksize
=
ksize
,
sk
=
sk
,
use_conv
=
use_conv
))
else
:
else
:
self
.
body
.
append
(
self
.
body
.
append
(
ResnetBlock
(
channels
[
i
],
channels
[
i
],
down
=
False
,
ksize
=
ksize
,
sk
=
sk
,
use_conv
=
use_conv
))
ResnetBlock
(
channels
[
i
],
channels
[
i
],
down
=
False
,
ksize
=
ksize
,
sk
=
sk
,
use_conv
=
use_conv
))
...
@@ -128,6 +141,14 @@ class Adapter(nn.Module):
...
@@ -128,6 +141,14 @@ class Adapter(nn.Module):
for
j
in
range
(
self
.
nums_rb
):
for
j
in
range
(
self
.
nums_rb
):
idx
=
i
*
self
.
nums_rb
+
j
idx
=
i
*
self
.
nums_rb
+
j
x
=
self
.
body
[
idx
](
x
)
x
=
self
.
body
[
idx
](
x
)
if
self
.
xl
:
features
.
append
(
None
)
if
i
==
0
:
features
.
append
(
None
)
features
.
append
(
None
)
if
i
==
2
:
features
.
append
(
None
)
else
:
features
.
append
(
None
)
features
.
append
(
None
)
features
.
append
(
None
)
features
.
append
(
None
)
features
.
append
(
x
)
features
.
append
(
x
)
...
@@ -243,10 +264,14 @@ class extractor(nn.Module):
...
@@ -243,10 +264,14 @@ class extractor(nn.Module):
class
Adapter_light
(
nn
.
Module
):
class
Adapter_light
(
nn
.
Module
):
def
__init__
(
self
,
channels
=
[
320
,
640
,
1280
,
1280
],
nums_rb
=
3
,
cin
=
64
):
def
__init__
(
self
,
channels
=
[
320
,
640
,
1280
,
1280
],
nums_rb
=
3
,
cin
=
64
):
super
(
Adapter_light
,
self
).
__init__
()
super
(
Adapter_light
,
self
).
__init__
()
self
.
unshuffle
=
nn
.
PixelUnshuffle
(
8
)
unshuffle
=
8
self
.
unshuffle
=
nn
.
PixelUnshuffle
(
unshuffle
)
self
.
input_channels
=
cin
//
(
unshuffle
*
unshuffle
)
self
.
channels
=
channels
self
.
channels
=
channels
self
.
nums_rb
=
nums_rb
self
.
nums_rb
=
nums_rb
self
.
body
=
[]
self
.
body
=
[]
self
.
xl
=
False
for
i
in
range
(
len
(
channels
)):
for
i
in
range
(
len
(
channels
)):
if
i
==
0
:
if
i
==
0
:
self
.
body
.
append
(
extractor
(
in_c
=
cin
,
inter_c
=
channels
[
i
]
//
4
,
out_c
=
channels
[
i
],
nums_rb
=
nums_rb
,
down
=
False
))
self
.
body
.
append
(
extractor
(
in_c
=
cin
,
inter_c
=
channels
[
i
]
//
4
,
out_c
=
channels
[
i
],
nums_rb
=
nums_rb
,
down
=
False
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment