Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
2b139390
Commit
2b139390
authored
Jul 30, 2023
by
comfyanonymous
Browse files
Remove some useless code.
parent
95d796fc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
241 deletions
+8
-241
comfy/cldm/cldm.py
comfy/cldm/cldm.py
+4
-15
comfy/ldm/modules/diffusionmodules/openaimodel.py
comfy/ldm/modules/diffusionmodules/openaimodel.py
+4
-226
No files found.
comfy/cldm/cldm.py
View file @
2b139390
...
...
@@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
)
from
..ldm.modules.attention
import
SpatialTransformer
from
..ldm.modules.diffusionmodules.openaimodel
import
UNetModel
,
TimestepEmbedSequential
,
ResBlock
,
Downsample
,
AttentionBlock
from
..ldm.modules.diffusionmodules.openaimodel
import
UNetModel
,
TimestepEmbedSequential
,
ResBlock
,
Downsample
from
..ldm.util
import
exists
...
...
@@ -57,6 +57,7 @@ class ControlNet(nn.Module):
transformer_depth_middle
=
None
,
):
super
().
__init__
()
assert
use_spatial_transformer
==
True
,
"use_spatial_transformer has to be true"
if
use_spatial_transformer
:
assert
context_dim
is
not
None
,
'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
...
...
@@ -200,13 +201,7 @@ class ControlNet(nn.Module):
if
not
exists
(
num_attention_blocks
)
or
nr
<
num_attention_blocks
[
level
]:
layers
.
append
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
if
not
use_spatial_transformer
else
SpatialTransformer
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
[
level
],
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
...
...
@@ -259,13 +254,7 @@ class ControlNet(nn.Module):
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
if
not
use_spatial_transformer
else
SpatialTransformer
(
# always uses a self-attn
SpatialTransformer
(
# always uses a self-attn
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth_middle
,
context_dim
=
context_dim
,
disable_self_attn
=
disable_middle_self_attn
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
...
...
comfy/ldm/modules/diffusionmodules/openaimodel.py
View file @
2b139390
...
...
@@ -19,45 +19,6 @@ from ..attention import SpatialTransformer
from
comfy.ldm.util
import
exists
# dummy replace
def
convert_module_to_f16
(
x
):
pass
def
convert_module_to_f32
(
x
):
pass
## go
class
AttentionPool2d
(
nn
.
Module
):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def
__init__
(
self
,
spacial_dim
:
int
,
embed_dim
:
int
,
num_heads_channels
:
int
,
output_dim
:
int
=
None
,
):
super
().
__init__
()
self
.
positional_embedding
=
nn
.
Parameter
(
th
.
randn
(
embed_dim
,
spacial_dim
**
2
+
1
)
/
embed_dim
**
0.5
)
self
.
qkv_proj
=
conv_nd
(
1
,
embed_dim
,
3
*
embed_dim
,
1
)
self
.
c_proj
=
conv_nd
(
1
,
embed_dim
,
output_dim
or
embed_dim
,
1
)
self
.
num_heads
=
embed_dim
//
num_heads_channels
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
def
forward
(
self
,
x
):
b
,
c
,
*
_spatial
=
x
.
shape
x
=
x
.
reshape
(
b
,
c
,
-
1
)
# NC(HW)
x
=
th
.
cat
([
x
.
mean
(
dim
=-
1
,
keepdim
=
True
),
x
],
dim
=-
1
)
# NC(HW+1)
x
=
x
+
self
.
positional_embedding
[
None
,
:,
:].
to
(
x
.
dtype
)
# NC(HW+1)
x
=
self
.
qkv_proj
(
x
)
x
=
self
.
attention
(
x
)
x
=
self
.
c_proj
(
x
)
return
x
[:,
:,
0
]
class
TimestepBlock
(
nn
.
Module
):
"""
Any module where forward() takes timestep embeddings as a second argument.
...
...
@@ -138,19 +99,6 @@ class Upsample(nn.Module):
x
=
self
.
conv
(
x
)
return
x
class
TransposedUpsample
(
nn
.
Module
):
'Learned 2x upsampling without padding'
def
__init__
(
self
,
channels
,
out_channels
=
None
,
ks
=
5
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
up
=
nn
.
ConvTranspose2d
(
self
.
channels
,
self
.
out_channels
,
kernel_size
=
ks
,
stride
=
2
)
def
forward
(
self
,
x
):
return
self
.
up
(
x
)
class
Downsample
(
nn
.
Module
):
"""
A downsampling layer with an optional convolution.
...
...
@@ -296,142 +244,6 @@ class ResBlock(TimestepBlock):
h
=
self
.
out_layers
(
h
)
return
self
.
skip_connection
(
x
)
+
h
class
AttentionBlock
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
use_new_attention_order
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
==
-
1
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
if
use_new_attention_order
:
# split qkv before split heads
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
else
:
# split heads before split qkv
self
.
attention
=
QKVAttentionLegacy
(
self
.
num_heads
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
def
forward
(
self
,
x
):
return
checkpoint
(
self
.
_forward
,
(
x
,),
self
.
parameters
(),
True
)
# TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
def
_forward
(
self
,
x
):
b
,
c
,
*
spatial
=
x
.
shape
x
=
x
.
reshape
(
b
,
c
,
-
1
)
qkv
=
self
.
qkv
(
self
.
norm
(
x
))
h
=
self
.
attention
(
qkv
)
h
=
self
.
proj_out
(
h
)
return
(
x
+
h
).
reshape
(
b
,
c
,
*
spatial
)
def
count_flops_attn
(
model
,
_x
,
y
):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b
,
c
,
*
spatial
=
y
[
0
].
shape
num_spatial
=
int
(
np
.
prod
(
spatial
))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops
=
2
*
b
*
(
num_spatial
**
2
)
*
c
model
.
total_ops
+=
th
.
DoubleTensor
([
matmul_ops
])
class
QKVAttentionLegacy
(
nn
.
Module
):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
th
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
th
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
th
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
@
staticmethod
def
count_flops
(
model
,
_x
,
y
):
return
count_flops_attn
(
model
,
_x
,
y
)
class
QKVAttention
(
nn
.
Module
):
"""
A module which performs QKV attention and splits in a different order.
"""
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
th
.
einsum
(
"bct,bcs->bts"
,
(
q
*
scale
).
view
(
bs
*
self
.
n_heads
,
ch
,
length
),
(
k
*
scale
).
view
(
bs
*
self
.
n_heads
,
ch
,
length
),
)
# More stable with f16 than dividing afterwards
weight
=
th
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
th
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
.
reshape
(
bs
*
self
.
n_heads
,
ch
,
length
))
return
a
.
reshape
(
bs
,
-
1
,
length
)
@
staticmethod
def
count_flops
(
model
,
_x
,
y
):
return
count_flops_attn
(
model
,
_x
,
y
)
class
Timestep
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
...
...
@@ -507,6 +319,7 @@ class UNetModel(nn.Module):
device
=
None
,
):
super
().
__init__
()
assert
use_spatial_transformer
==
True
,
"use_spatial_transformer has to be true"
if
use_spatial_transformer
:
assert
context_dim
is
not
None
,
'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
...
...
@@ -631,14 +444,7 @@ class UNetModel(nn.Module):
disabled_sa
=
False
if
not
exists
(
num_attention_blocks
)
or
nr
<
num_attention_blocks
[
level
]:
layers
.
append
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
if
not
use_spatial_transformer
else
SpatialTransformer
(
layers
.
append
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
[
level
],
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
...
...
@@ -693,13 +499,7 @@ class UNetModel(nn.Module):
dtype
=
self
.
dtype
,
device
=
device
,
),
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
if
not
use_spatial_transformer
else
SpatialTransformer
(
# always uses a self-attn
SpatialTransformer
(
# always uses a self-attn
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth_middle
,
context_dim
=
context_dim
,
disable_self_attn
=
disable_middle_self_attn
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
...
...
@@ -751,13 +551,7 @@ class UNetModel(nn.Module):
if
not
exists
(
num_attention_blocks
)
or
i
<
num_attention_blocks
[
level
]:
layers
.
append
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads_upsample
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
)
if
not
use_spatial_transformer
else
SpatialTransformer
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
[
level
],
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
...
...
@@ -797,22 +591,6 @@ class UNetModel(nn.Module):
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def
convert_to_fp16
(
self
):
"""
Convert the torso of the model to float16.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f16
)
self
.
middle_block
.
apply
(
convert_module_to_f16
)
self
.
output_blocks
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
):
"""
Convert the torso of the model to float32.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
def
forward
(
self
,
x
,
timesteps
=
None
,
context
=
None
,
y
=
None
,
control
=
None
,
transformer_options
=
{},
**
kwargs
):
"""
Apply the model to an input batch.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment