chenpangpang / ComfyUI / Commits / 871cc20e

Commit 871cc20e authored Nov 23, 2023 by comfyanonymous

Support SVD img2vid model.

parent 022033a0
Showing 11 changed files with 1030 additions and 100 deletions (+1030 −100)
comfy/cldm/cldm.py                                  +1    −0
comfy/ldm/modules/attention.py                      +236  −35
comfy/ldm/modules/diffusionmodules/openaimodel.py   +292  −56
comfy/ldm/modules/diffusionmodules/util.py          +68   −1
comfy/ldm/modules/temporal_ae.py                    +244  −0
comfy/model_base.py                                 +53   −3
comfy/model_detection.py                            +17   −1
comfy/model_sampling.py                             +45   −1
comfy/sd.py                                         +9    −1
comfy/supported_models.py                           +34   −2
comfy_extras/nodes_model_advanced.py                +31   −0
comfy/cldm/cldm.py

@@ -54,6 +54,7 @@ class ControlNet(nn.Module):
                 transformer_depth_output=None,
                 device=None,
                 operations=comfy.ops,
+                **kwargs,
                 ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
comfy/ldm/modules/attention.py

@@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial

-from .diffusionmodules.util import checkpoint
+from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management

@@ -370,21 +372,45 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
+                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
+
+        self.ff_in = ff_in or inner_dim is not None
+        if inner_dim is None:
+            inner_dim = dim
+
+        self.is_res = inner_dim == dim
+
+        if self.ff_in:
+            self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
+            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
+
        self.disable_self_attn = disable_self_attn
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+        self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
-        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
+
+        if disable_temporal_crossattention:
+            if switch_temporal_ca_to_sa:
+                raise ValueError
+            else:
+                self.attn2 = None
+        else:
+            context_dim_attn2 = None
+            if not switch_temporal_ca_to_sa:
+                context_dim_attn2 = context_dim
+
+            self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                            heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none

-        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+
        self.checkpoint = checkpoint
        self.n_heads = n_heads
        self.d_head = d_head
+        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

    def forward(self, x, context=None, transformer_options={}):
        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

@@ -418,6 +444,12 @@ class BasicTransformerBlock(nn.Module):
        else:
            transformer_patches_replace = {}

+        if self.ff_in:
+            x_skip = x
+            x = self.ff_in(self.norm_in(x))
+            if self.is_res:
+                x += x_skip
+
        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context

@@ -465,8 +497,11 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                x = p(x, extra_options)

+        if self.attn2 is not None:
            n = self.norm2(x)
+            if self.switch_temporal_ca_to_sa:
+                context_attn2 = n
+            else:
                context_attn2 = context
            value_attn2 = None
            if "attn2_patch" in transformer_patches:

@@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
                n = p(n, extra_options)

        x += n
-        x = self.ff(self.norm3(x)) + x
+        if self.is_res:
+            x_skip = x
+        x = self.ff(self.norm3(x))
+        if self.is_res:
+            x += x_skip

        return x

@@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module):
        x = self.proj_out(x)
        return x + x_in

+
+class SpatialVideoTransformer(SpatialTransformer):
+    def __init__(
+        self,
+        in_channels,
+        n_heads,
+        d_head,
+        depth=1,
+        dropout=0.0,
+        use_linear=False,
+        context_dim=None,
+        use_spatial_context=False,
+        timesteps=None,
+        merge_strategy: str = "fixed",
+        merge_factor: float = 0.5,
+        time_context_dim=None,
+        ff_in=False,
+        checkpoint=False,
+        time_depth=1,
+        disable_self_attn=False,
+        disable_temporal_crossattention=False,
+        max_time_embed_period: int = 10000,
+        dtype=None, device=None, operations=comfy.ops
+    ):
+        super().__init__(
+            in_channels,
+            n_heads,
+            d_head,
+            depth=depth,
+            dropout=dropout,
+            use_checkpoint=checkpoint,
+            context_dim=context_dim,
+            use_linear=use_linear,
+            disable_self_attn=disable_self_attn,
+            dtype=dtype, device=device, operations=operations
+        )
+        self.time_depth = time_depth
+        self.depth = depth
+        self.max_time_embed_period = max_time_embed_period
+
+        time_mix_d_head = d_head
+        n_time_mix_heads = n_heads
+
+        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
+
+        inner_dim = n_heads * d_head
+        if use_spatial_context:
+            time_context_dim = context_dim
+
+        self.time_stack = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    n_time_mix_heads,
+                    time_mix_d_head,
+                    dropout=dropout,
+                    context_dim=time_context_dim,
+                    # timesteps=timesteps,
+                    checkpoint=checkpoint,
+                    ff_in=ff_in,
+                    inner_dim=time_mix_inner_dim,
+                    disable_self_attn=disable_self_attn,
+                    disable_temporal_crossattention=disable_temporal_crossattention,
+                    dtype=dtype, device=device, operations=operations
+                )
+                for _ in range(self.depth)
+            ]
+        )
+
+        assert len(self.time_stack) == len(self.transformer_blocks)
+
+        self.use_spatial_context = use_spatial_context
+        self.in_channels = in_channels
+
+        time_embed_dim = self.in_channels * 4
+        self.time_pos_embed = nn.Sequential(
+            operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
+        )
+
+        self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
+        time_context: Optional[torch.Tensor] = None,
+        timesteps: Optional[int] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+        transformer_options={}
+    ) -> torch.Tensor:
+        _, _, h, w = x.shape
+        x_in = x
+        spatial_context = None
+        if exists(context):
+            spatial_context = context
+
+        if self.use_spatial_context:
+            assert (
+                context.ndim == 3
+            ), f"n dims of spatial context should be 3 but are {context.ndim}"
+
+            if time_context is None:
+                time_context = context
+            time_context_first_timestep = time_context[::timesteps]
+            time_context = repeat(
+                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
+            )
+        elif time_context is not None and not self.use_spatial_context:
+            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
+            if time_context.ndim == 2:
+                time_context = rearrange(time_context, "b c -> b 1 c")
+
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        if self.use_linear:
+            x = self.proj_in(x)
+
+        num_frames = torch.arange(timesteps, device=x.device)
+        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+        num_frames = rearrange(num_frames, "b t -> (b t)")
+        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
+        emb = self.time_pos_embed(t_emb)
+        emb = emb[:, None, :]
+
+        for it_, (block, mix_block) in enumerate(
+            zip(self.transformer_blocks, self.time_stack)
+        ):
+            transformer_options["block_index"] = it_
+            x = block(
+                x,
+                context=spatial_context,
+                transformer_options=transformer_options,
+            )
+
+            x_mix = x
+            x_mix = x_mix + emb
+
+            B, S, C = x_mix.shape
+            x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
+            x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
+            x_mix = rearrange(
+                x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
+            )
+
+            x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
+
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        if not self.use_linear:
+            x = self.proj_out(x)
+        out = x + x_in
+        return out
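Aside on the temporal mixing above: the time_stack blocks see the same tokens as the spatial blocks, but with the frame axis moved into the sequence position, so attention runs across frames at each pixel location. A tiny round-trip check of the two rearranges used in SpatialVideoTransformer.forward (the sizes here are illustrative, not taken from the commit):

    import torch
    from einops import rearrange

    b, t, s, c = 2, 4, 16, 8                  # videos, frames, spatial tokens, channels (made-up sizes)
    x = torch.randn(b * t, s, c)              # "(b t) s c", the layout entering the loop above
    x_mix = rearrange(x, "(b t) s c -> (b s) t c", t=t)   # frames become the attended sequence
    x_back = rearrange(x_mix, "(b s) t c -> (b t) s c", s=s, b=b, c=c, t=t)
    assert torch.equal(x, x_back)             # lossless reshuffle, as relied on by the blending step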
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
+from einops import rearrange
+from functools import partial

from .util import (
    checkpoint,

@@ -12,8 +14,9 @@ from .util import (
    zero_module,
    normalization,
    timestep_embedding,
+    AlphaBlender,
)
-from ..attention import SpatialTransformer
+from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops

@@ -29,10 +32,15 @@ class TimestepBlock(nn.Module):
    """

#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
-def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
+def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
    for layer in ts:
-        if isinstance(layer, TimestepBlock):
+        if isinstance(layer, VideoResBlock):
+            x = layer(x, emb, num_video_frames, image_only_indicator)
+        elif isinstance(layer, TimestepBlock):
            x = layer(x, emb)
+        elif isinstance(layer, SpatialVideoTransformer):
+            x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
+            transformer_options["current_index"] += 1
        elif isinstance(layer, SpatialTransformer):
            x = layer(x, context, transformer_options)
            if "current_index" in transformer_options:

@@ -145,6 +153,9 @@ class ResBlock(TimestepBlock):
        use_checkpoint=False,
        up=False,
        down=False,
+        kernel_size=3,
+        exchange_temb_dims=False,
+        skip_t_emb=False,
        dtype=None,
        device=None,
        operations=comfy.ops

@@ -157,11 +168,17 @@ class ResBlock(TimestepBlock):
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
+        self.exchange_temb_dims = exchange_temb_dims
+
+        if isinstance(kernel_size, list):
+            padding = [k // 2 for k in kernel_size]
+        else:
+            padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
-            operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
+            operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
        )
        self.updown = up or down

@@ -175,6 +192,11 @@ class ResBlock(TimestepBlock):
        else:
            self.h_upd = self.x_upd = nn.Identity()

+        self.skip_t_emb = skip_t_emb
+        if self.skip_t_emb:
+            self.emb_layers = None
+            self.exchange_temb_dims = False
+        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                operations.Linear(

@@ -187,7 +209,7 @@ class ResBlock(TimestepBlock):
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
-                operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
+                operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
            ),
        )

@@ -195,7 +217,7 @@ class ResBlock(TimestepBlock):
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = operations.conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
+                dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
            )
        else:
            self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

@@ -221,19 +243,110 @@ class ResBlock(TimestepBlock):
            h = in_conv(h)
        else:
            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
+
+        emb_out = None
+        if not self.skip_t_emb:
+            emb_out = self.emb_layers(emb).type(h.dtype)
+            while len(emb_out.shape) < len(h.shape):
+                emb_out = emb_out[..., None]
+
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
+            h = out_norm(h)
+            if emb_out is not None:
+                scale, shift = th.chunk(emb_out, 2, dim=1)
+                h *= (1 + scale)
+                h += shift
            h = out_rest(h)
        else:
-            h = h + emb_out
+            if emb_out is not None:
+                if self.exchange_temb_dims:
+                    emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+                h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

+
+class VideoResBlock(ResBlock):
+    def __init__(
+        self,
+        channels: int,
+        emb_channels: int,
+        dropout: float,
+        video_kernel_size=3,
+        merge_strategy: str = "fixed",
+        merge_factor: float = 0.5,
+        out_channels=None,
+        use_conv: bool = False,
+        use_scale_shift_norm: bool = False,
+        dims: int = 2,
+        use_checkpoint: bool = False,
+        up: bool = False,
+        down: bool = False,
+        dtype=None,
+        device=None,
+        operations=comfy.ops
+    ):
+        super().__init__(
+            channels,
+            emb_channels,
+            dropout,
+            out_channels=out_channels,
+            use_conv=use_conv,
+            use_scale_shift_norm=use_scale_shift_norm,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            up=up,
+            down=down,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+        self.time_stack = ResBlock(
+            default(out_channels, channels),
+            emb_channels,
+            dropout=dropout,
+            dims=3,
+            out_channels=default(out_channels, channels),
+            use_scale_shift_norm=False,
+            use_conv=False,
+            up=False,
+            down=False,
+            kernel_size=video_kernel_size,
+            use_checkpoint=use_checkpoint,
+            exchange_temb_dims=True,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+        self.time_mixer = AlphaBlender(
+            alpha=merge_factor,
+            merge_strategy=merge_strategy,
+            rearrange_pattern="b t -> b 1 t 1 1",
+        )
+
+    def forward(
+        self,
+        x: th.Tensor,
+        emb: th.Tensor,
+        num_video_frames: int,
+        image_only_indicator=None,
+    ) -> th.Tensor:
+        x = super().forward(x, emb)
+
+        x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+
+        x = self.time_stack(
+            x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
+        )
+        x = self.time_mixer(
+            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
+        )
+        x = rearrange(x, "b c t h w -> (b t) c h w")
+        return x
+

class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()

@@ -310,6 +423,16 @@ class UNetModel(nn.Module):
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
+        use_temporal_resblock=False,
+        use_temporal_attention=False,
+        time_context_dim=None,
+        extra_ff_mix_layer=False,
+        use_spatial_context=False,
+        merge_strategy=None,
+        merge_factor=0.0,
+        video_kernel_size=None,
+        disable_temporal_crossattention=False,
+        max_ddpm_temb_period=10000,
        device=None,
        operations=comfy.ops,
    ):

@@ -364,8 +487,12 @@ class UNetModel(nn.Module):
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
+        self.use_temporal_resblocks = use_temporal_resblock
        self.predict_codebook_ids = n_embed is not None

+        self.default_num_video_frames = None
+        self.default_image_only_indicator = None
+
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),

@@ -402,13 +529,104 @@ class UNetModel(nn.Module):
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for nr in range(self.num_res_blocks[level]):
-                layers = [
-                    ResBlock(
+
+        def get_attention_layer(
+            ch,
+            num_heads,
+            dim_head,
+            depth=1,
+            context_dim=None,
+            use_checkpoint=False,
+            disable_self_attn=False,
+        ):
+            if use_temporal_attention:
+                return SpatialVideoTransformer(
+                    ch,
+                    num_heads,
+                    dim_head,
+                    depth=depth,
+                    context_dim=context_dim,
+                    time_context_dim=time_context_dim,
+                    dropout=dropout,
+                    ff_in=extra_ff_mix_layer,
+                    use_spatial_context=use_spatial_context,
+                    merge_strategy=merge_strategy,
+                    merge_factor=merge_factor,
+                    checkpoint=use_checkpoint,
+                    use_linear=use_linear_in_transformer,
+                    disable_self_attn=disable_self_attn,
+                    disable_temporal_crossattention=disable_temporal_crossattention,
+                    max_time_embed_period=max_ddpm_temb_period,
+                    dtype=self.dtype, device=device, operations=operations
+                )
+            else:
+                return SpatialTransformer(
+                    ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
+                    disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                )
+
+        def get_resblock(
+            merge_factor,
+            merge_strategy,
+            video_kernel_size,
+            ch,
+            time_embed_dim,
+            dropout,
+            out_channels,
+            dims,
+            use_checkpoint,
+            use_scale_shift_norm,
+            down=False,
+            up=False,
+            dtype=None,
+            device=None,
+            operations=comfy.ops
+        ):
+            if self.use_temporal_resblocks:
+                return VideoResBlock(
+                    merge_factor=merge_factor,
+                    merge_strategy=merge_strategy,
+                    video_kernel_size=video_kernel_size,
+                    channels=ch,
+                    emb_channels=time_embed_dim,
+                    dropout=dropout,
+                    out_channels=out_channels,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    down=down,
+                    up=up,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations
+                )
+            else:
+                return ResBlock(
+                    channels=ch,
+                    emb_channels=time_embed_dim,
+                    dropout=dropout,
+                    out_channels=out_channels,
+                    use_checkpoint=use_checkpoint,
+                    dims=dims,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    down=down,
+                    up=up,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations
+                )
+
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,

@@ -435,11 +653,9 @@ class UNetModel(nn.Module):
                    disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-                        layers.append(SpatialTransformer(
-                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-                            )
-                        )
+                        layers.append(get_attention_layer(
+                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
+                            )
+                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

@@ -448,10 +664,13 @@ class UNetModel(nn.Module):
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,

@@ -481,10 +700,14 @@ class UNetModel(nn.Module):
        #num_heads = 1
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_block = [
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                time_embed_dim=time_embed_dim,
+                dropout=dropout,
+                out_channels=None,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,

@@ -493,15 +716,18 @@ class UNetModel(nn.Module):
                operations=operations
            )]
        if transformer_depth_middle >= 0:
-            mid_block += [SpatialTransformer(  # always uses a self-attn
+            mid_block += [get_attention_layer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
                        ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                time_embed_dim=time_embed_dim,
+                dropout=dropout,
+                out_channels=None,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,

@@ -517,10 +743,13 @@ class UNetModel(nn.Module):
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch + ich,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,

@@ -548,19 +777,21 @@ class UNetModel(nn.Module):
                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
-                            SpatialTransformer(
+                            get_attention_layer(
                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                                disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,

@@ -602,6 +833,10 @@ class UNetModel(nn.Module):
        transformer_options["current_index"] = 0
        transformer_patches = transformer_options.get("patches", {})

+        num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
+        image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
+        time_context = kwargs.get("time_context", None)
+
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

@@ -616,7 +851,7 @@ class UNetModel(nn.Module):
        h = x.type(self.dtype)
        for id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", id)
-            h = forward_timestep_embed(module, h, emb, context, transformer_options)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
            h = apply_control(h, control, 'input')
            if "input_block_patch" in transformer_patches:
                patch = transformer_patches["input_block_patch"]

@@ -630,9 +865,10 @@ class UNetModel(nn.Module):
                    h = p(h, transformer_options)

        transformer_options["block"] = ("middle", 0)
-        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
+        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
        h = apply_control(h, control, 'middle')

        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = hs.pop()

@@ -649,7 +885,7 @@ class UNetModel(nn.Module):
                output_shape = hs[-1].shape
            else:
                output_shape = None
-            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
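A note on the VideoResBlock added above: the spatial ResBlock output is folded from a flat frame batch into a 5-D video tensor so the 3-D time_stack can convolve along the frame axis, then folded back. A toy illustration of that reshape (shapes are made up for the example):

    import torch
    from einops import rearrange

    num_video_frames = 4
    x = torch.randn(2 * num_video_frames, 320, 32, 32)                  # "(b t) c h w"
    x5d = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)  # video layout for Conv3d
    print(x5d.shape)                                                    # torch.Size([2, 320, 4, 32, 32])
    x_back = rearrange(x5d, "b c t h w -> (b t) c h w")
    assert torch.equal(x, x_back)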
comfy/ldm/modules/diffusionmodules/util.py

@@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
-from einops import repeat
+from einops import repeat, rearrange

from comfy.ldm.util import instantiate_from_config
import comfy.ops

+class AlphaBlender(nn.Module):
+    strategies = ["learned", "fixed", "learned_with_images"]
+
+    def __init__(
+        self,
+        alpha: float,
+        merge_strategy: str = "learned_with_images",
+        rearrange_pattern: str = "b t -> (b t) 1 1",
+    ):
+        super().__init__()
+        self.merge_strategy = merge_strategy
+        self.rearrange_pattern = rearrange_pattern
+
+        assert (
+            merge_strategy in self.strategies
+        ), f"merge_strategy needs to be in {self.strategies}"
+
+        if self.merge_strategy == "fixed":
+            self.register_buffer("mix_factor", torch.Tensor([alpha]))
+        elif (
+            self.merge_strategy == "learned"
+            or self.merge_strategy == "learned_with_images"
+        ):
+            self.register_parameter(
+                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+            )
+        else:
+            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
+        # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
+        if self.merge_strategy == "fixed":
+            # make shape compatible
+            # alpha = repeat(self.mix_factor, '1 -> b () t  () ()', t=t, b=bs)
+            alpha = self.mix_factor
+        elif self.merge_strategy == "learned":
+            alpha = torch.sigmoid(self.mix_factor)
+            # make shape compatible
+            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
+        elif self.merge_strategy == "learned_with_images":
+            assert image_only_indicator is not None, "need image_only_indicator ..."
+            alpha = torch.where(
+                image_only_indicator.bool(),
+                torch.ones(1, 1, device=image_only_indicator.device),
+                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+            )
+            alpha = rearrange(alpha, self.rearrange_pattern)
+            # make shape compatible
+            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
+        else:
+            raise NotImplementedError()
+        return alpha
+
+    def forward(
+        self,
+        x_spatial,
+        x_temporal,
+        image_only_indicator=None,
+    ) -> torch.Tensor:
+        alpha = self.get_alpha(image_only_indicator)
+        x = (
+            alpha.to(x_spatial.dtype) * x_spatial
+            + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
+        )
+        return x
+

def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
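AlphaBlender reduces to one convex combination, alpha * x_spatial + (1 - alpha) * x_temporal, with alpha either fixed, learned, or gated per-sample by image_only_indicator. A minimal sketch of exercising it in isolation (ComfyUI importable on the path is assumed; shapes are illustrative):

    import torch
    from comfy.ldm.modules.diffusionmodules.util import AlphaBlender

    x_spatial = torch.randn(8, 64, 32, 32)    # e.g. 2 videos x 4 frames flattened into the batch dim
    x_temporal = torch.randn(8, 64, 32, 32)

    blender = AlphaBlender(alpha=0.5, merge_strategy="fixed")
    out = blender(x_spatial=x_spatial, x_temporal=x_temporal)   # 0.5 * x_spatial + 0.5 * x_temporal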
comfy/ldm/modules/temporal_ae.py  (new file, 0 → 100644)

import functools
from typing import Callable, Iterable, Union

import torch
from einops import rearrange, repeat

import comfy.ops

from .diffusionmodules.model import (
    AttnBlock,
    Decoder,
    ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock


def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        b, c, h, w = x.shape
        if timesteps is None:
            timesteps = b

        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x


class AE3DConv(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps=None, skip_video=False):
        if timesteps is None:
            timesteps = input.shape[0]
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")


class AttnVideoBlock(AttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = BasicTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
        )

        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            comfy.ops.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            comfy.ops.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps=None, skip_time_block=False):
        if skip_time_block:
            return super().forward(x)

        if timesteps is None:
            timesteps = x.shape[0]

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha()
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(
        self,
    ):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")


def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
):
    return partialclass(
        AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
    )


class Conv2DWrapper(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        return super().forward(input)


class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"

        if self.time_mode != "attn-only":
            kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)

        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )
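partialclass above is the glue that lets the base Decoder keep instantiating its conv / attention / resnet ops with its usual arguments while the video-specific settings (kernel size, alpha, merge strategy) are baked in ahead of time. A quick standalone illustration of the helper:

    import functools

    def partialclass(cls, *args, **kwargs):
        class NewCls(cls):
            __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
        return NewCls

    class Greeter:
        def __init__(self, greeting, name):
            self.text = f"{greeting}, {name}!"

    HelloGreeter = partialclass(Greeter, "Hello")   # "Hello" is pre-bound, like video_kernel_size above
    print(HelloGreeter("SVD").text)                 # Hello, SVD!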
comfy/model_base.py

@@ -10,17 +10,22 @@ from . import utils
class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2
+    V_PREDICTION_EDM = 3

-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM


def model_sampling(model_config, model_type):
+    s = ModelSamplingDiscrete
+
    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
-
-    s = ModelSamplingDiscrete
+    elif model_type == ModelType.V_PREDICTION_EDM:
+        c = V_PREDICTION
+        s = ModelSamplingContinuousEDM

    class ModelSampling(s, c):
        pass

@@ -262,3 +267,48 @@ class SDXL(BaseModel):
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+
+class SVD_img2vid(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.embedder = Timestep(256)
+
+    def encode_adm(self, **kwargs):
+        fps_id = kwargs.get("fps", 6) - 1
+        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
+        augmentation = kwargs.get("augmentation_level", 0)
+
+        out = []
+        out.append(self.embedder(torch.Tensor([fps_id])))
+        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
+        out.append(self.embedder(torch.Tensor([augmentation])))
+
+        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
+        return flat
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+
+        latent_image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if latent_image is None:
+            latent_image = torch.zeros_like(noise)
+
+        if latent_image.shape[1:] != noise.shape[1:]:
+            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
+
+        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
+
+        if "time_conditioning" in kwargs:
+            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
+
+        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
+        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
+        return out
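SVD_img2vid.encode_adm concatenates three 256-dimensional Fourier embeddings (fps_id, motion_bucket_id, augmentation_level) into one 768-wide vector, which is why the SVD config added later in this commit sets adm_in_channels to 768. A rough standalone sketch of the same construction (the values are just examples):

    import torch
    from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep

    embedder = Timestep(256)
    parts = [embedder(torch.Tensor([v])) for v in (5, 127, 0)]   # fps_id, motion_bucket_id, augmentation_level
    adm = torch.flatten(torch.cat(parts)).unsqueeze(dim=0)
    print(adm.shape)   # torch.Size([1, 768])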
comfy/model_detection.py

@@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
        last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
        context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
        use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
-        return last_transformer_depth, context_dim, use_linear_in_transformer
+        time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
+        return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
    return None

def detect_unet_config(state_dict, key_prefix, dtype):

@@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
    context_dim = None
    use_linear_in_transformer = False

+    video_model = False

    current_res = 1
    count = 0

@@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
                if context_dim is None:
                    context_dim = out[1]
                    use_linear_in_transformer = out[2]
+                    video_model = out[3]
            else:
                transformer_depth.append(0)

@@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
    unet_config["transformer_depth_middle"] = transformer_depth_middle
    unet_config['use_linear_in_transformer'] = use_linear_in_transformer
    unet_config["context_dim"] = context_dim

+    if video_model:
+        unet_config["extra_ff_mix_layer"] = True
+        unet_config["use_spatial_context"] = True
+        unet_config["merge_strategy"] = "learned_with_images"
+        unet_config["merge_factor"] = 0.0
+        unet_config["video_kernel_size"] = [3, 1, 1]
+        unet_config["use_temporal_resblock"] = True
+        unet_config["use_temporal_attention"] = True
+    else:
+        unet_config["use_temporal_resblock"] = False
+        unet_config["use_temporal_attention"] = False
+
    return unet_config

def model_config_from_unet_config(unet_config):
comfy/model_sampling.py

import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
+import math

class EPS:
    def calculate_input(self, sigma, noise):

@@ -83,3 +83,47 @@ class ModelSamplingDiscrete(torch.nn.Module):
        percent = 1.0 - percent
        return self.sigma(torch.tensor(percent * 999.0)).item()

+
+class ModelSamplingContinuousEDM(torch.nn.Module):
+    def __init__(self, model_config=None):
+        super().__init__()
+        self.sigma_data = 1.0
+
+        if model_config is not None:
+            sampling_settings = model_config.sampling_settings
+        else:
+            sampling_settings = {}
+
+        sigma_min = sampling_settings.get("sigma_min", 0.002)
+        sigma_max = sampling_settings.get("sigma_max", 120.0)
+        self.set_sigma_range(sigma_min, sigma_max)
+
+    def set_sigma_range(self, sigma_min, sigma_max):
+        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
+
+        self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        return 0.25 * sigma.log()
+
+    def sigma(self, timestep):
+        return (timestep / 0.25).exp()
+
+    def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return 999999999.9
+        if percent >= 1.0:
+            return 0.0
+        percent = 1.0 - percent
+
+        log_sigma_min = math.log(self.sigma_min)
+        return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
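ModelSamplingContinuousEDM maps between sigma and a continuous timestep with t = 0.25 * ln(sigma) and sigma = exp(t / 0.25), and percent_to_sigma interpolates log-uniformly between sigma_min and sigma_max. A small numeric check of that relationship (using the SVD sigma range that appears later in this commit):

    import math

    sigma_min, sigma_max = 0.002, 700.0
    t_min, t_max = 0.25 * math.log(sigma_min), 0.25 * math.log(sigma_max)
    print(round(t_min, 3), round(t_max, 3))   # about -1.554 and 1.638
    print(math.exp(t_max / 0.25))             # recovers 700.0 (up to float error)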
comfy/sd.py

@@ -159,7 +159,15 @@ class VAE:
        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)

        if config is None:
-            if "taesd_decoder.1.weight" in sd:
+            if "decoder.mid.block_1.mix_factor" in sd:
+                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+                decoder_config = encoder_config.copy()
+                decoder_config["video_kernel_size"] = [3, 1, 1]
+                decoder_config["alpha"] = 0.0
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
+                                                            decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
+            elif "taesd_decoder.1.weight" in sd:
                self.first_stage_model = comfy.taesd.taesd.TAESD()
            else:
                #default SD1.x/SD2.x VAE parameters
comfy/supported_models.py

@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
+        "use_temporal_attention": False,
    }

    unet_extra_config = {

@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
+        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SD15

@@ -88,6 +90,7 @@ class SD21UnclipL(SD20):
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
+        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."

@@ -100,6 +103,7 @@ class SD21UnclipH(SD20):
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
+        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."

@@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE):
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
+        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

@@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE):
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

@@ -203,8 +209,34 @@ class SSD1B(SDXL):
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
    }

+class SVD_img2vid(supported_models_base.BASE):
+    unet_config = {
+        "model_channels": 320,
+        "in_channels": 8,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
+        "context_dim": 1024,
+        "adm_in_channels": 768,
+        "use_temporal_attention": True,
+        "use_temporal_resblock": True
+    }
+
+    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
+
+    latent_format = latent_formats.SD15
+
+    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SVD_img2vid(self, device=device)
+        return out
+
+    def clip_target(self):
+        return None

models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
+models += [SVD_img2vid]
comfy_extras/nodes_model_advanced.py

@@ -128,6 +128,36 @@ class ModelSamplingDiscrete:
        m.add_object_patch("model_sampling", model_sampling)
        return (m, )

+class ModelSamplingContinuousEDM:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "sampling": (["v_prediction", "eps"],),
+                              "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step": 0.001, "round": False}),
+                              "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step": 0.001, "round": False}),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, sampling, sigma_max, sigma_min):
+        m = model.clone()
+
+        if sampling == "eps":
+            sampling_type = comfy.model_sampling.EPS
+        elif sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+
+        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced()
+        model_sampling.set_sigma_range(sigma_min, sigma_max)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
class RescaleCFG:
    @classmethod
    def INPUT_TYPES(s):

@@ -169,5 +199,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
    "ModelSamplingDiscrete": ModelSamplingDiscrete,
+    "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
    "RescaleCFG": RescaleCFG,
}
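The node's patch method builds a sampling object by mixing the continuous-EDM schedule with the chosen prediction class and attaches it as an object patch. The same pattern can be reproduced outside the node system; model below stands for any already-loaded ModelPatcher and is an assumption for illustration:

    import comfy.model_sampling

    class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, comfy.model_sampling.V_PREDICTION):
        pass

    model_sampling = ModelSamplingAdvanced()
    model_sampling.set_sigma_range(0.002, 700.0)          # e.g. the SVD range
    m = model.clone()                                     # model: a loaded ModelPatcher (assumed)
    m.add_object_patch("model_sampling", model_sampling)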