chenpangpang / ComfyUI / Commits

Commit 871cc20e, authored Nov 23, 2023 by comfyanonymous

    Support SVD img2vid model.

Parent: 022033a0

Showing 11 changed files with 1030 additions and 100 deletions (+1030 −100).
Changed files:

    comfy/cldm/cldm.py                                   +1    −0
    comfy/ldm/modules/attention.py                       +236  −35
    comfy/ldm/modules/diffusionmodules/openaimodel.py    +292  −56
    comfy/ldm/modules/diffusionmodules/util.py           +68   −1
    comfy/ldm/modules/temporal_ae.py                     +244  −0
    comfy/model_base.py                                  +53   −3
    comfy/model_detection.py                             +17   −1
    comfy/model_sampling.py                              +45   −1
    comfy/sd.py                                          +9    −1
    comfy/supported_models.py                            +34   −2
    comfy_extras/nodes_model_advanced.py                 +31   −0
comfy/cldm/cldm.py

@@ -54,6 +54,7 @@ class ControlNet(nn.Module):
                 transformer_depth_output=None,
                 device=None,
                 operations=comfy.ops,
                 **kwargs,
                 ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
comfy/ldm/modules/attention.py

@@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial

-from .diffusionmodules.util import checkpoint
+from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention

from comfy import model_management

@@ -370,21 +372,45 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
+                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
        super().__init__()

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        self.is_res = inner_dim == dim

        if self.ff_in:
            self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        self.disable_self_attn = disable_self_attn
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+        self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                     context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
-        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
+
+        if disable_temporal_crossattention:
+            if switch_temporal_ca_to_sa:
+                raise ValueError
+            else:
+                self.attn2 = None
+        else:
+            context_dim_attn2 = None
+            if not switch_temporal_ca_to_sa:
+                context_dim_attn2 = context_dim
+
+            self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                                         heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
+            self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+
+        self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.checkpoint = checkpoint
+        self.n_heads = n_heads
+        self.d_head = d_head
+        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

    def forward(self, x, context=None, transformer_options={}):
        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

@@ -418,6 +444,12 @@ class BasicTransformerBlock(nn.Module):
        else:
            transformer_patches_replace = {}

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context

@@ -465,8 +497,11 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                x = p(x, extra_options)

        if self.attn2 is not None:
            n = self.norm2(x)
            if self.switch_temporal_ca_to_sa:
                context_attn2 = n
            else:
                context_attn2 = context
            value_attn2 = None
            if "attn2_patch" in transformer_patches:

@@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
                    n = p(n, extra_options)

        x += n
-        x = self.ff(self.norm3(x)) + x
+        if self.is_res:
+            x_skip = x
+        x = self.ff(self.norm3(x))
+        if self.is_res:
+            x += x_skip

        return x

@@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module):
        x = self.proj_out(x)
        return x + x_in


class SpatialVideoTransformer(SpatialTransformer):
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        max_time_embed_period: int = 10000,
        dtype=None, device=None, operations=comfy.ops
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
            dtype=dtype, device=device, operations=operations
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        self.time_stack = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    # timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_stack) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        self.time_pos_embed = nn.Sequential(
            operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
        )

        self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        transformer_options={}
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            if time_context is None:
                time_context = context
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_stack)
        ):
            transformer_options["block_index"] = it_
            x = block(
                x,
                context=spatial_context,
                transformer_options=transformer_options,
            )

            x_mix = x
            x_mix = x_mix + emb

            B, S, C = x_mix.shape
            x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
            x_mix = mix_block(x_mix, context=time_context)  #TODO: transformer_options
            x_mix = rearrange(
                x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
            )

            x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
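The temporal blocks above reuse the spatial transformer's token layout and only change which axis plays the role of the sequence: tokens shaped ((b·t), h·w, c) are regrouped so every spatial location becomes its own batch entry and attention inside time_stack runs across the t frames. The following standalone sketch (plain torch/einops with made-up sizes, not ComfyUI code) shows just that reshape and its inverse:

import torch
from einops import rearrange

b, t, s, c = 2, 14, 64, 320            # batch, frames, spatial tokens (h*w), channels; illustrative only
x = torch.randn(b * t, s, c)           # layout the spatial transformer blocks operate on

# every spatial location becomes its own batch entry; the sequence axis is now time
x_mix = rearrange(x, "(b t) s c -> (b s) t c", t=t)
assert x_mix.shape == (b * s, t, c)

# ... the temporal BasicTransformerBlock (time_stack) would run on x_mix here ...

# fold back so AlphaBlender can mix it with the untouched spatial branch
x_mix = rearrange(x_mix, "(b s) t c -> (b t) s c", s=s, b=b, t=t)
assert x_mix.shape == (b * t, s, c)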
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial

from .util import (
    checkpoint,

@@ -12,8 +14,9 @@ from .util import (
    zero_module,
    normalization,
    timestep_embedding,
    AlphaBlender,
)
-from ..attention import SpatialTransformer
+from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops

@@ -29,10 +32,15 @@ class TimestepBlock(nn.Module):
    """

#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
-def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
+def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
    for layer in ts:
-        if isinstance(layer, TimestepBlock):
+        if isinstance(layer, VideoResBlock):
+            x = layer(x, emb, num_video_frames, image_only_indicator)
+        elif isinstance(layer, TimestepBlock):
            x = layer(x, emb)
+        elif isinstance(layer, SpatialVideoTransformer):
+            x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
+            transformer_options["current_index"] += 1
        elif isinstance(layer, SpatialTransformer):
            x = layer(x, context, transformer_options)
            if "current_index" in transformer_options:

@@ -145,6 +153,9 @@ class ResBlock(TimestepBlock):
        use_checkpoint=False,
        up=False,
        down=False,
        kernel_size=3,
        exchange_temb_dims=False,
        skip_t_emb=False,
        dtype=None,
        device=None,
        operations=comfy.ops

@@ -157,11 +168,17 @@ class ResBlock(TimestepBlock):
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        if isinstance(kernel_size, list):
            padding = [k // 2 for k in kernel_size]
        else:
            padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
-            operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
+            operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
        )
        self.updown = up or down

@@ -175,6 +192,11 @@ class ResBlock(TimestepBlock):
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        if self.skip_t_emb:
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                operations.Linear(

@@ -187,7 +209,7 @@ class ResBlock(TimestepBlock):
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
-                operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
+                operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
            ),
        )

@@ -195,7 +217,7 @@ class ResBlock(TimestepBlock):
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = operations.conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
+                dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
            )
        else:
            self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

@@ -221,19 +243,110 @@ class ResBlock(TimestepBlock):
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        emb_out = None
        if not self.skip_t_emb:
            emb_out = self.emb_layers(emb).type(h.dtype)
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            h = out_norm(h)
            if emb_out is not None:
                scale, shift = th.chunk(emb_out, 2, dim=1)
-                h = out_norm(h) * (1 + scale) + shift
+                h *= (1 + scale)
+                h += shift
            h = out_rest(h)
        else:
            if emb_out is not None:
                if self.exchange_temb_dims:
                    emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
                h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class VideoResBlock(ResBlock):
    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        video_kernel_size=3,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        out_channels=None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
        dtype=None,
        device=None,
        operations=comfy.ops
    ):
        super().__init__(
            channels,
            emb_channels,
            dropout,
            out_channels=out_channels,
            use_conv=use_conv,
            use_scale_shift_norm=use_scale_shift_norm,
            dims=dims,
            use_checkpoint=use_checkpoint,
            up=up,
            down=down,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.time_stack = ResBlock(
            default(out_channels, channels),
            emb_channels,
            dropout=dropout,
            dims=3,
            out_channels=default(out_channels, channels),
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=use_checkpoint,
            exchange_temb_dims=True,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            rearrange_pattern="b t -> b 1 t 1 1",
        )

    def forward(
        self,
        x: th.Tensor,
        emb: th.Tensor,
        num_video_frames: int,
        image_only_indicator=None,
    ) -> th.Tensor:
        x = super().forward(x, emb)

        x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)

        x = self.time_stack(
            x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
        )
        x = self.time_mixer(
            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
        )
        x = rearrange(x, "b c t h w -> (b t) c h w")
        return x


class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()

@@ -310,6 +423,16 @@ class UNetModel(nn.Module):
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        use_temporal_resblock=False,
        use_temporal_attention=False,
        time_context_dim=None,
        extra_ff_mix_layer=False,
        use_spatial_context=False,
        merge_strategy=None,
        merge_factor=0.0,
        video_kernel_size=None,
        disable_temporal_crossattention=False,
        max_ddpm_temb_period=10000,
        device=None,
        operations=comfy.ops,
    ):

@@ -364,8 +487,12 @@ class UNetModel(nn.Module):
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.use_temporal_resblocks = use_temporal_resblock
        self.predict_codebook_ids = n_embed is not None

        self.default_num_video_frames = None
        self.default_image_only_indicator = None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),

@@ -402,13 +529,104 @@ class UNetModel(nn.Module):
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for nr in range(self.num_res_blocks[level]):
-                layers = [
-                    ResBlock(

        def get_attention_layer(
            ch,
            num_heads,
            dim_head,
            depth=1,
            context_dim=None,
            use_checkpoint=False,
            disable_self_attn=False,
        ):
            if use_temporal_attention:
                return SpatialVideoTransformer(
                    ch,
                    num_heads,
                    dim_head,
                    depth=depth,
                    context_dim=context_dim,
                    time_context_dim=time_context_dim,
                    dropout=dropout,
                    ff_in=extra_ff_mix_layer,
                    use_spatial_context=use_spatial_context,
                    merge_strategy=merge_strategy,
                    merge_factor=merge_factor,
                    checkpoint=use_checkpoint,
                    use_linear=use_linear_in_transformer,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    max_time_embed_period=max_ddpm_temb_period,
                    dtype=self.dtype, device=device, operations=operations
                )
            else:
                return SpatialTransformer(
                    ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
                    disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                )

        def get_resblock(
            merge_factor,
            merge_strategy,
            video_kernel_size,
            ch,
            time_embed_dim,
            dropout,
            out_channels,
            dims,
            use_checkpoint,
            use_scale_shift_norm,
            down=False,
            up=False,
            dtype=None,
            device=None,
            operations=comfy.ops
        ):
            if self.use_temporal_resblocks:
                return VideoResBlock(
                    merge_factor=merge_factor,
                    merge_strategy=merge_strategy,
                    video_kernel_size=video_kernel_size,
                    channels=ch,
                    emb_channels=time_embed_dim,
                    dropout=dropout,
                    out_channels=out_channels,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                    down=down,
                    up=up,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
            else:
                return ResBlock(
                    channels=ch,
                    emb_channels=time_embed_dim,
                    dropout=dropout,
                    out_channels=out_channels,
                    use_checkpoint=use_checkpoint,
                    dims=dims,
                    use_scale_shift_norm=use_scale_shift_norm,
                    down=down,
                    up=up,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )

+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,

@@ -435,11 +653,9 @@ class UNetModel(nn.Module):
                disabled_sa = False

            if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-                layers.append(SpatialTransformer(
+                layers.append(get_attention_layer(
                        ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-                        disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                        use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                        disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
                    )
                )
        self.input_blocks.append(TimestepEmbedSequential(*layers))
        self._feature_size += ch

@@ -448,10 +664,13 @@ class UNetModel(nn.Module):
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,

@@ -481,10 +700,14 @@ class UNetModel(nn.Module):
        #num_heads = 1
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_block = [
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                time_embed_dim=time_embed_dim,
+                dropout=dropout,
+                out_channels=None,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,

@@ -493,15 +716,18 @@ class UNetModel(nn.Module):
                operations=operations
            )]
        if transformer_depth_middle >= 0:
-            mid_block += [SpatialTransformer(  # always uses a self-attn
+            mid_block += [get_attention_layer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
                        ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                time_embed_dim=time_embed_dim,
+                dropout=dropout,
+                out_channels=None,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,

@@ -517,10 +743,13 @@ class UNetModel(nn.Module):
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch + ich,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,

@@ -548,19 +777,21 @@ class UNetModel(nn.Module):
                if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                    layers.append(
-                        SpatialTransformer(
+                        get_attention_layer(
                            ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
                        )
                    )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,

@@ -602,6 +833,10 @@ class UNetModel(nn.Module):
        transformer_options["current_index"] = 0
        transformer_patches = transformer_options.get("patches", {})

        num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
        image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
        time_context = kwargs.get("time_context", None)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

@@ -616,7 +851,7 @@ class UNetModel(nn.Module):
        h = x.type(self.dtype)
        for id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", id)
-            h = forward_timestep_embed(module, h, emb, context, transformer_options)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
            h = apply_control(h, control, 'input')
            if "input_block_patch" in transformer_patches:
                patch = transformer_patches["input_block_patch"]

@@ -630,9 +865,10 @@ class UNetModel(nn.Module):
                    h = p(h, transformer_options)

        transformer_options["block"] = ("middle", 0)
-        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
+        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
        h = apply_control(h, control, 'middle')

        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = hs.pop()

@@ -649,7 +885,7 @@ class UNetModel(nn.Module):
                output_shape = hs[-1].shape
            else:
                output_shape = None
-            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
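VideoResBlock applies the same frame/batch swap at the convolution level: the 2D ResBlock output, stacked as ((b·t), c, h, w), is folded into a 5D (b, c, t, h, w) tensor so the 3D time_stack (built with dims=3 and the [3, 1, 1] video_kernel_size that model_detection.py selects) can convolve along the frame axis before AlphaBlender merges the result back with the spatial branch. Below is a minimal standalone sketch of those rearranges, with a plain Conv3d standing in for the temporal ResBlock and purely illustrative sizes:

import torch
from einops import rearrange

b, t, c, h, w = 1, 14, 320, 32, 32                           # illustrative sizes
x = torch.randn(b * t, c, h, w)                               # frames stacked into the batch dimension

x_spatial = rearrange(x, "(b t) c h w -> b c t h w", t=t)     # kept aside for the blend
x_temporal = rearrange(x, "(b t) c h w -> b c t h w", t=t)    # goes through the temporal stack

# kernel_size=[3, 1, 1] with padding=[1, 0, 0] (kernel_size // 2) mixes neighbouring frames only
conv_t = torch.nn.Conv3d(c, c, kernel_size=(3, 1, 1), padding=(1, 0, 0))
x_temporal = conv_t(x_temporal)

# after AlphaBlender(x_spatial, x_temporal) the frames are folded back into the batch axis
out = rearrange(x_temporal, "b c t h w -> (b t) c h w")
assert out.shape == (b * t, c, h, w)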
comfy/ldm/modules/diffusionmodules/util.py

@@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
-from einops import repeat
+from einops import repeat, rearrange

from comfy.ldm.util import instantiate_from_config
import comfy.ops


class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        assert (
            merge_strategy in self.strategies
        ), f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif (
            self.merge_strategy == "learned"
            or self.merge_strategy == "learned_with_images"
        ):
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
        if self.merge_strategy == "fixed":
            # make shape compatible
            # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
            alpha = self.mix_factor
        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)
            # make shape compatible
            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
        elif self.merge_strategy == "learned_with_images":
            assert image_only_indicator is not None, "need image_only_indicator ..."
            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
            )
            alpha = rearrange(alpha, self.rearrange_pattern)
            # make shape compatible
            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
        else:
            raise NotImplementedError()
        return alpha

    def forward(
        self,
        x_spatial,
        x_temporal,
        image_only_indicator=None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator)
        x = (
            alpha.to(x_spatial.dtype) * x_spatial
            + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        )
        return x


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
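For the "learned_with_images" strategy that model_detection.py selects for video models, the blend reduces to: alpha = 1 wherever image_only_indicator marks a frame as image-only (keep the spatial branch untouched), and alpha = sigmoid(mix_factor) everywhere else. The following plain-torch sketch of that rule is not the AlphaBlender class itself; the shapes are illustrative, and SVD_img2vid passes a zero indicator so every frame is treated as video:

import torch

mix_factor = torch.nn.Parameter(torch.Tensor([0.0]))        # merge_factor=0.0, as set in model_detection.py
image_only_indicator = torch.zeros((1, 14))                  # 0 = video frame, 1 = image-only frame

alpha = torch.where(
    image_only_indicator.bool(),
    torch.ones(1, 1),                                         # image-only: output is purely spatial
    torch.sigmoid(mix_factor)[..., None],                     # video: learned spatial weight, here sigmoid(0) = 0.5
)
alpha = alpha.reshape(-1, 1, 1)                               # the "b t -> (b t) 1 1" rearrange_pattern

x_spatial = torch.randn(14, 64, 320)
x_temporal = torch.randn(14, 64, 320)
x = alpha * x_spatial + (1.0 - alpha) * x_temporal            # an even mix until mix_factor is trained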
comfy/ldm/modules/temporal_ae.py (new file, mode 100644)

import functools
from typing import Callable, Iterable, Union

import torch
from einops import rearrange, repeat

import comfy.ops

from .diffusionmodules.model import (
    AttnBlock,
    Decoder,
    ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock


def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        b, c, h, w = x.shape
        if timesteps is None:
            timesteps = b

        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x


class AE3DConv(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps=None, skip_video=False):
        if timesteps is None:
            timesteps = input.shape[0]
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")


class AttnVideoBlock(AttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = BasicTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
        )

        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            comfy.ops.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            comfy.ops.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps=None, skip_time_block=False):
        if skip_time_block:
            return super().forward(x)

        if timesteps is None:
            timesteps = x.shape[0]

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha()
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(self,):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")


def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
):
    return partialclass(
        AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
    )


class Conv2DWrapper(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        return super().forward(input)


class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"

        if self.time_mode != "attn-only":
            kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)

        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )
comfy/model_base.py

@@ -10,17 +10,22 @@ from . import utils
class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2
    V_PREDICTION_EDM = 3

-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM


def model_sampling(model_config, model_type):
+    s = ModelSamplingDiscrete
+
    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
-    s = ModelSamplingDiscrete
+    elif model_type == ModelType.V_PREDICTION_EDM:
+        c = V_PREDICTION
+        s = ModelSamplingContinuousEDM

    class ModelSampling(s, c):
        pass

@@ -262,3 +267,48 @@ class SDXL(BaseModel):
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)


class SVD_img2vid(BaseModel):
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
        fps_id = kwargs.get("fps", 6) - 1
        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
        augmentation = kwargs.get("augmentation_level", 0)

        out = []
        out.append(self.embedder(torch.Tensor([fps_id])))
        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
        out.append(self.embedder(torch.Tensor([augmentation])))

        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat

    def extra_conds(self, **kwargs):
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        if "time_conditioning" in kwargs:
            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
        return out
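Each of the three scalar conditions (fps minus one, motion_bucket_id, augmentation_level) goes through the 256-dimensional Timestep embedder and the results are concatenated, which is why the SVD_img2vid config in supported_models.py declares adm_in_channels of 768. A hedged sketch of that assembly follows, assuming Timestep(256) computes the usual ldm-style sinusoidal timestep embedding:

import math
import torch

def sinusoidal_embedding(value, dim=256, max_period=10000):
    # assumed stand-in for Timestep(256): standard ldm-style sinusoidal embedding
    t = torch.Tensor([value])
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

fps_id = 6 - 1                    # kwargs.get("fps", 6) - 1
motion_bucket_id = 127            # default motion_bucket_id
augmentation = 0                  # default augmentation_level

out = [sinusoidal_embedding(v) for v in (fps_id, motion_bucket_id, augmentation)]
adm = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
print(adm.shape)                  # torch.Size([1, 768]); matches adm_in_channels=768 for SVD_img2vid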
comfy/model_detection.py

@@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
        last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
        context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
        use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
-        return last_transformer_depth, context_dim, use_linear_in_transformer
+        time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
+        return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
    return None

def detect_unet_config(state_dict, key_prefix, dtype):

@@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
    context_dim = None
    use_linear_in_transformer = False

    video_model = False

    current_res = 1
    count = 0

@@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
                if context_dim is None:
                    context_dim = out[1]
                    use_linear_in_transformer = out[2]
                    video_model = out[3]
            else:
                transformer_depth.append(0)

@@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
    unet_config["transformer_depth_middle"] = transformer_depth_middle
    unet_config['use_linear_in_transformer'] = use_linear_in_transformer
    unet_config["context_dim"] = context_dim

    if video_model:
        unet_config["extra_ff_mix_layer"] = True
        unet_config["use_spatial_context"] = True
        unet_config["merge_strategy"] = "learned_with_images"
        unet_config["merge_factor"] = 0.0
        unet_config["video_kernel_size"] = [3, 1, 1]
        unet_config["use_temporal_resblock"] = True
        unet_config["use_temporal_attention"] = True
    else:
        unet_config["use_temporal_resblock"] = False
        unet_config["use_temporal_attention"] = False
    return unet_config

def model_config_from_unet_config(unet_config):
comfy/model_sampling.py

import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math


class EPS:
    def calculate_input(self, sigma, noise):

@@ -83,3 +83,47 @@ class ModelSamplingDiscrete(torch.nn.Module):
        percent = 1.0 - percent
        return self.sigma(torch.tensor(percent * 999.0)).item()


class ModelSamplingContinuousEDM(torch.nn.Module):
    def __init__(self, model_config=None):
        super().__init__()
        self.sigma_data = 1.0

        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}

        sigma_min = sampling_settings.get("sigma_min", 0.002)
        sigma_max = sampling_settings.get("sigma_max", 120.0)
        self.set_sigma_range(sigma_min, sigma_max)

    def set_sigma_range(self, sigma_min, sigma_max):
        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()

        self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
        self.register_buffer('log_sigmas', sigmas.log())

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        return 0.25 * sigma.log()

    def sigma(self, timestep):
        return (timestep / 0.25).exp()

    def percent_to_sigma(self, percent):
        if percent <= 0.0:
            return 999999999.9
        if percent >= 1.0:
            return 0.0
        percent = 1.0 - percent

        log_sigma_min = math.log(self.sigma_min)
        return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
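ModelSamplingContinuousEDM keeps 1000 log-spaced sigmas between sigma_min and sigma_max and maps a scheduler "percent" back onto that range in log space. A quick standalone check of the arithmetic, using the sigma range (0.002 to 700.0) that supported_models.py assigns to SVD_img2vid:

import math
import torch

sigma_min, sigma_max = 0.002, 700.0                          # SVD_img2vid sampling_settings
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
assert abs(sigmas[0].item() - sigma_min) < 1e-6
assert abs(sigmas[-1].item() - sigma_max) < 1e-2

def percent_to_sigma(percent):
    # same mapping as ModelSamplingContinuousEDM.percent_to_sigma
    if percent <= 0.0:
        return 999999999.9
    if percent >= 1.0:
        return 0.0
    percent = 1.0 - percent
    log_sigma_min = math.log(sigma_min)
    return math.exp((math.log(sigma_max) - log_sigma_min) * percent + log_sigma_min)

print(percent_to_sigma(0.5))    # ~1.18, the geometric mean of sigma_min and sigma_max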
comfy/sd.py

@@ -159,7 +159,15 @@ class VAE:
        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)

        if config is None:
-            if "taesd_decoder.1.weight" in sd:
+            if "decoder.mid.block_1.mix_factor" in sd:
+                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+                decoder_config = encoder_config.copy()
+                decoder_config["video_kernel_size"] = [3, 1, 1]
+                decoder_config["alpha"] = 0.0
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
+                                                            decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
+            elif "taesd_decoder.1.weight" in sd:
                self.first_stage_model = comfy.taesd.taesd.TAESD()
            else:
                #default SD1.x/SD2.x VAE parameters
comfy/supported_models.py

@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
+        "use_temporal_attention": False,
    }

    unet_extra_config = {

@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
+        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SD15

@@ -88,6 +90,7 @@ class SD21UnclipL(SD20):
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
+        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."

@@ -100,6 +103,7 @@ class SD21UnclipH(SD20):
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
+        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."

@@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE):
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
+        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

@@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE):
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

@@ -203,8 +209,34 @@ class SSD1B(SDXL):
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
    }


class SVD_img2vid(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 768,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

    latent_format = latent_formats.SD15

    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SVD_img2vid(self, device=device)
        return out

    def clip_target(self):
        return None

models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
+models += [SVD_img2vid]
comfy_extras/nodes_model_advanced.py

@@ -128,6 +128,36 @@ class ModelSamplingDiscrete:
        m.add_object_patch("model_sampling", model_sampling)
        return (m, )


class ModelSamplingContinuousEDM:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "sampling": (["v_prediction", "eps"],),
                             "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step": 0.001, "round": False}),
                             "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step": 0.001, "round": False}),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, sampling, sigma_max, sigma_min):
        m = model.clone()

        if sampling == "eps":
            sampling_type = comfy.model_sampling.EPS
        elif sampling == "v_prediction":
            sampling_type = comfy.model_sampling.V_PREDICTION

        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced()
        model_sampling.set_sigma_range(sigma_min, sigma_max)
        m.add_object_patch("model_sampling", model_sampling)
        return (m, )


class RescaleCFG:
    @classmethod
    def INPUT_TYPES(s):

@@ -169,5 +199,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
    "ModelSamplingDiscrete": ModelSamplingDiscrete,
+    "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
    "RescaleCFG": RescaleCFG,
}
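A hedged sketch of calling the new node class directly from Python rather than through the graph UI; `model` is assumed to be a MODEL patcher returned by a checkpoint loader, and the sigma range shown is the one SVD_img2vid declares in supported_models.py:

# hypothetical direct call; in normal use these values come from the node's UI widgets
node = ModelSamplingContinuousEDM()
(patched_model,) = node.patch(model, sampling="v_prediction", sigma_max=700.0, sigma_min=0.002)
# patched_model now carries a ModelSamplingAdvanced object patch that combines
# ModelSamplingContinuousEDM with V_PREDICTION sampling.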