chenpangpang / ComfyUI / Commits / 25853d0b

Commit 25853d0b authored Jul 30, 2024 by comfyanonymous

Use common function for casting weights to input.

parent 79040635
Showing 7 changed files with 51 additions and 31 deletions (+51 -31)

comfy/ldm/audio/dit.py                         +13 -10
comfy/ldm/aura/mmdit.py                         +3  -2
comfy/ldm/cascade/common.py                     +4 -11
comfy/ldm/hydit/models.py                       +3  -2
comfy/ldm/hydit/poolers.py                      +3  -3
comfy/ldm/modules/diffusionmodules/mmdit.py     +2  -1
comfy/ops.py                                   +23  -2
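In short, every ad-hoc `weight.to(dtype=input.dtype, device=input.device)` call site below is rewritten to go through the new helper in comfy/ops.py. A minimal sketch of the before/after pattern; the tensors are illustrative and it assumes the ComfyUI repo is importable:

    import torch
    import comfy.ops

    x = torch.randn(2, 4, dtype=torch.float16)    # layer input (illustrative)
    weight = torch.randn(4, dtype=torch.float32)  # stored parameter/buffer (illustrative)

    # Before this commit, call sites cast by hand:
    w_old = weight.to(dtype=x.dtype, device=x.device)

    # After, they call the shared helper added in comfy/ops.py:
    w_new = comfy.ops.cast_to_input(weight, x)

    assert w_new.dtype == x.dtype and w_new.device == x.device

Centralizing the cast also puts the optional non_blocking transfer flag in one place instead of at every call site.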
comfy/ldm/audio/dit.py  (view file @ 25853d0b)

@@ -9,6 +9,7 @@ from einops import rearrange
 from torch import nn
 from torch.nn import functional as F
 import math
+import comfy.ops

 class FourierFeatures(nn.Module):
     def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
             [out_features // 2, in_features], dtype=dtype, device=device))

     def forward(self, input):
-        f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
+        f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
         return torch.cat([f.cos(), f.sin()], dim=-1)

 # norms
@@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         beta = self.beta
         if self.beta is not None:
-            beta = beta.to(dtype=x.dtype, device=x.device)
-        return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
+            beta = comfy.ops.cast_to_input(beta, x)
+        return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)

 class GLU(nn.Module):
     def __init__(
@@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
         scale_base=512,
         interpolation_factor=1.,
         base=10000,
-        base_rescale_factor=1.
+        base_rescale_factor=1.,
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))

-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
+        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

         assert interpolation_factor >= 1.
         self.interpolation_factor = interpolation_factor
@@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
         t = t / self.interpolation_factor

-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
+        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
         freqs = torch.cat((freqs, freqs), dim=-1)

         if self.scale is None:
             return freqs, 1.

         power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
-        scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
+        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim=-1)

         return freqs, scale
@@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
         self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

         if rotary_pos_emb:
-            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
+            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
         else:
             self.rotary_pos_emb = None
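Note that RotaryEmbedding no longer computes inv_freq in __init__: it registers an uninitialized buffer with the requested dtype/device (the original formula is kept as a comment) and casts it to the input at forward time. A self-contained sketch of that pattern; the class name is illustrative, and the load_state_dict call stands in for however the real values are populated:

    import torch
    import torch.nn as nn

    class LazyRotary(nn.Module):
        """Illustrative only: register an uninitialized buffer, fill it later."""
        def __init__(self, dim, dtype=None, device=None):
            super().__init__()
            # Placeholder buffer; the contents are expected to come from a checkpoint.
            self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

        def forward(self, t):
            # Cast the stored buffer to match the input, mirroring comfy.ops.cast_to_input.
            inv_freq = self.inv_freq.to(device=t.device, dtype=t.dtype)
            return torch.einsum('i , j -> i j', t, inv_freq)

    dim = 8
    module = LazyRotary(dim)
    # Hypothetical checkpoint values, computed with the commented-out formula.
    state = {'inv_freq': 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))}
    module.load_state_dict(state)
    freqs = module(torch.arange(16, dtype=torch.float32))  # shape (16, dim // 2)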
comfy/ldm/aura/mmdit.py  (view file @ 25853d0b)

@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops

 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -427,7 +428,7 @@ class MMDiT(nn.Module):
         max_dim = max(h, w)

         cur_dim = self.h_max
-        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
+        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)

         if max_dim > cur_dim:
             pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
@@ -455,7 +456,7 @@ class MMDiT(nn.Module):
         t = timestep

         c = self.cond_seq_linear(c_seq)  # B, T_c, D
-        c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
+        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)

         global_cond = self.t_embedder(t, x.dtype)  # B, D
comfy/ldm/cascade/common.py  (view file @ 25853d0b)

@@ -19,14 +19,7 @@
 import torch
 import torch.nn as nn
 from comfy.ldm.modules.attention import optimized_attention
-
-class Linear(torch.nn.Linear):
-    def reset_parameters(self):
-        return None
-
-class Conv2d(torch.nn.Conv2d):
-    def reset_parameters(self):
-        return None
+import comfy.ops

 class OptimizedAttention(nn.Module):
     def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
@@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module):
     "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
     def __init__(self, dim, dtype=None, device=None):
         super().__init__()
-        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
-        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
+        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))

     def forward(self, x):
         Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
-        return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
+        return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x

 class ResBlock(nn.Module):
comfy/ldm/hydit/models.py  (view file @ 25853d0b)

@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+import comfy.ops
 from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
@@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module):
         if self.use_style_cond:
             # Here we use a default learned embedder layer for future extension.
-            self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
+            self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
             self.extra_in_dim += hidden_size

         # Text embedding for `add`
@@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module):
         b_t5, l_t5, c_t5 = text_states_t5.shape
         text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

-        padding = self.text_embedding_padding.to(text_states)
+        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

         text_states[:, -self.text_len:] = torch.where(text_states_mask[:, -self.text_len:].unsqueeze(2), text_states[:, -self.text_len:], padding[:self.text_len])
         text_states_t5[:, -self.text_len_t5:] = torch.where(text_states_t5_mask[:, -self.text_len_t5:].unsqueeze(2), text_states_t5[:, -self.text_len_t5:], padding[self.text_len:])
comfy/ldm/hydit/poolers.py  (view file @ 25853d0b)

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.attention import optimized_attention #TODO
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops

 class AttentionPool(nn.Module):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
@@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
         x = x[:, :self.positional_embedding.shape[0] - 1]
         x = x.permute(1, 0, 2)  # NLC -> LNC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
-        x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device)  # (L+1)NC
+        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC
         q = self.q_proj(x[:1])
         k = self.k_proj(x)
comfy/ldm/modules/diffusionmodules/mmdit.py  (view file @ 25853d0b)

@@ -8,6 +8,7 @@ import torch.nn as nn
 from .. import attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
+import comfy.ops

 def default(x, y):
     if x is not None:
@@ -926,7 +927,7 @@ class MMDiT(nn.Module):
         context = self.context_processor(context)

         hw = x.shape[-2:]
-        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+        x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
         c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         if y is not None and self.y_embedder is not None:
             y = self.y_embedder(y)  # (N, D)
comfy/ops.py  (view file @ 25853d0b)

@@ -19,14 +19,17 @@
 import torch
 import comfy.model_management

+def cast_to_input(weight, input, non_blocking=False):
+    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+
 def cast_bias_weight(s, input):
     bias = None
     non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
     if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias
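cast_bias_weight builds on the same helper: it casts a layer's weight and bias to the input's device and dtype (non-blocking where the device allows) and then applies the optional weight_function / bias_function hooks. A minimal sketch of a cast-on-forward linear layer written against it; TinyCastLinear is illustrative, not part of the commit, and in practice the mismatch is usually fp16/bf16 activations against fp32 storage:

    import torch
    import comfy.ops

    class TinyCastLinear(torch.nn.Linear):
        """Illustrative only: keep weights in their stored dtype, cast per forward call."""
        weight_function = None   # optional hooks consulted by cast_bias_weight
        bias_function = None

        def forward(self, input):
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

    layer = TinyCastLinear(8, 4, dtype=torch.float32)   # weights stored in fp32
    x = torch.randn(2, 8, dtype=torch.float64)          # inputs in a different dtype
    y = layer(x)                                        # weight/bias cast to match x on the fly
    assert y.dtype == torch.float64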
@@ -168,6 +171,21 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

+    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
+        def reset_parameters(self):
+            self.bias = None
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
@@ -202,3 +220,6 @@ class manual_cast(disable_weight_init):
     class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
         comfy_cast_weights = True
+
+    class Embedding(disable_weight_init.Embedding):
+        comfy_cast_weights = True
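With the new disable_weight_init.Embedding and manual_cast.Embedding classes, models can build embeddings through the injected operations namespace, as the HunYuanDiT hunk above now does for style_embedder. A small sketch of that wiring; TinyStyleCond is illustrative, and the in-place weight copy stands in for checkpoint loading:

    import torch
    import torch.nn as nn
    import comfy.ops

    class TinyStyleCond(nn.Module):
        """Illustrative only: build the embedding through an injected ops namespace."""
        def __init__(self, hidden_size, dtype=None, device=None, operations=comfy.ops.disable_weight_init):
            super().__init__()
            self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)

        def forward(self, style_idx):
            return self.style_embedder(style_idx)

    m = TinyStyleCond(16, dtype=torch.float32)          # default ops skip weight init
    with torch.no_grad():
        m.style_embedder.weight.copy_(torch.randn(1, 16))  # stand-in for load_state_dict
    out = m(torch.zeros(2, dtype=torch.long))            # shape (2, 16)

Passing operations=comfy.ops.manual_cast instead flips comfy_cast_weights to True, so the class added in the last hunk routes calls through forward_comfy_cast_weights rather than the stock nn.Embedding forward.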