chenpangpang / ComfyUI / Commits / 9d00235b
Commit 9d00235b, authored Mar 03, 2023 by comfyanonymous

Update T2I adapter code to latest.

parent 3ddff339
Changes: 1 changed file with 147 additions and 13 deletions

comfy/t2i_adapter/adapter.py  (+147 -13, view file @ 9d00235b)
#taken from https://github.com/TencentARC/T2I-Adapter

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock


def conv_nd(dims, *args, **kwargs):
    """
...
@@ -17,6 +16,7 @@ def conv_nd(dims, *args, **kwargs):
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
...
@@ -29,6 +29,7 @@ def avg_pool_nd(dims, *args, **kwargs):
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
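The bodies of conv_nd and avg_pool_nd are mostly collapsed in this view; only the 3D branch and the fallback error are visible. A minimal sketch of how these dispatchers are used, assuming the elided 1D/2D branches mirror the Conv3d/AvgPool3d case shown above:

    # Illustrative only -- assumes the collapsed branches return nn.Conv1d/nn.Conv2d
    # and nn.AvgPool1d/nn.AvgPool2d in the same pattern as the 3D case.
    conv = conv_nd(2, 64, 128, 3, padding=1)   # expected: nn.Conv2d(64, 128, 3, padding=1)
    pool = avg_pool_nd(2, kernel_size=2)       # expected: nn.AvgPool2d(kernel_size=2)
    conv_nd(4, 64, 128, 3)                     # raises ValueError("unsupported dimensions: 4")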
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
...
@@ -38,7 +39,7 @@ class Downsample(nn.Module):
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
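The rest of Downsample (the operator construction and forward) is collapsed in this diff. A hedged usage sketch, assuming the elided code behaves like the upstream T2I-Adapter Downsample (stride-2 convolution when use_conv=True, average pooling otherwise), so the inner-two spatial dimensions are halved:

    # Assumption: the collapsed forward halves H and W.
    down = Downsample(channels=64, use_conv=True)   # 2D by default (dims=2)
    y = down(torch.randn(1, 64, 32, 32))            # expected shape: (1, 64, 16, 16)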
...
@@ -61,8 +62,8 @@ class Downsample(nn.Module):


class ResnetBlock(nn.Module):
    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        ps = ksize // 2
        if in_c != out_c or sk == False:
            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            # print('n_in')
...
@@ -70,7 +71,7 @@ class ResnetBlock(nn.Module):
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        if sk == False:
            self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            self.skep = None
...
@@ -82,7 +83,7 @@ class ResnetBlock(nn.Module):
    def forward(self, x):
        if self.down == True:
            x = self.down_opt(x)
        if self.in_conv is not None:  # edit
            x = self.in_conv(x)
        h = self.block1(x)
...
@@ -103,12 +104,14 @@ class Adapter(nn.Module):
        self.body = []
        for i in range(len(channels)):
            for j in range(nums_rb):
                if (i != 0) and (j == 0):
                    self.body.append(ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
                else:
                    self.body.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
        self.body = nn.ModuleList(self.body)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)

    def forward(self, x):
        # unshuffle
...
@@ -118,8 +121,139 @@ class Adapter(nn.Module):
        x = self.conv_in(x)
        for i in range(len(self.channels)):
            for j in range(self.nums_rb):
                idx = i * self.nums_rb + j
                x = self.body[idx](x)
            features.append(x)

        return features

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
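This override runs the normalization in float32 and casts the result back, so half-precision activations keep their dtype. A small illustrative check (not part of the commit):

    # Illustrative only: fp16 input stays fp16 while the math runs in fp32.
    ln = LayerNorm(768)
    x = torch.randn(2, 77, 768).half()
    print(ln(x).dtype)   # torch.float16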

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class StyleAdapter(nn.Module):

    def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
        super().__init__()

        scale = width ** -0.5
        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
        self.num_token = num_token
        self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
        self.ln_post = LayerNorm(width)
        self.ln_pre = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, context_dim))

    def forward(self, x):
        # x shape [N, HW+1, C]
        style_embedding = self.style_embedding + torch.zeros(
            (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
        x = torch.cat([x, style_embedding], dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer_layes(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, -self.num_token:, :])
        x = x @ self.proj

        return x
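A hedged usage sketch for the new StyleAdapter, using the defaults above. The input is assumed to be a CLIP-vision-style token sequence, based only on the [N, HW+1, C] comment in forward(); the output is num_token style tokens projected into context_dim:

    # Illustrative only: shapes follow from the code above; the input source is an assumption.
    adapter = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4)
    x = torch.randn(2, 257, 1024)     # [N, HW+1, C] per the comment in forward()
    style_tokens = adapter(x)
    print(style_tokens.shape)         # torch.Size([2, 4, 768])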

class ResnetBlock_light(nn.Module):
    def __init__(self, in_c):
        super().__init__()
        self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)

    def forward(self, x):
        h = self.block1(x)
        h = self.act(h)
        h = self.block2(h)

        return h + x


class extractor(nn.Module):
    def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
        super().__init__()
        self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
        self.body = []
        for _ in range(nums_rb):
            self.body.append(ResnetBlock_light(inter_c))
        self.body = nn.Sequential(*self.body)
        self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
        self.down = down
        if self.down == True:
            self.down_opt = Downsample(in_c, use_conv=False)

    def forward(self, x):
        if self.down == True:
            x = self.down_opt(x)
        x = self.in_conv(x)
        x = self.body(x)
        x = self.out_conv(x)

        return x


class Adapter_light(nn.Module):
    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
        super(Adapter_light, self).__init__()
        self.unshuffle = nn.PixelUnshuffle(8)
        self.channels = channels
        self.nums_rb = nums_rb
        self.body = []
        for i in range(len(channels)):
            if i == 0:
                self.body.append(extractor(in_c=cin, inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=False))
            else:
                self.body.append(extractor(in_c=channels[i - 1], inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=True))
        self.body = nn.ModuleList(self.body)

    def forward(self, x):
        # unshuffle
        x = self.unshuffle(x)
        # extract features
        features = []
        for i in range(len(self.channels)):
            x = self.body[i](x)
            features.append(x)

        return features
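A hedged usage sketch for the new Adapter_light. The single-channel 512x512 condition map is an assumption; it is chosen because the default cin=64 equals 1 channel times 8x8 after PixelUnshuffle(8), and the stage resolutions assume the Downsample inside extractor halves H and W:

    # Illustrative only: a 1-channel condition map becomes 64 channels after PixelUnshuffle(8),
    # matching the default cin=64; every stage after the first downsamples by 2.
    adapter = Adapter_light(channels=[320, 640, 1280, 1280], nums_rb=3, cin=64)
    cond = torch.randn(1, 1, 512, 512)
    feats = adapter(cond)
    print([f.shape for f in feats])
    # expected (assuming Downsample halves H and W):
    # (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)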