chenpangpang / ComfyUI · Commits

Commit 25853d0b, authored Jul 30, 2024 by comfyanonymous
Parent: 79040635

Use common function for casting weights to input.

Showing 7 changed files with 51 additions and 31 deletions (+51 −31)
comfy/ldm/audio/dit.py                         +13 −10
comfy/ldm/aura/mmdit.py                         +3  −2
comfy/ldm/cascade/common.py                     +4 −11
comfy/ldm/hydit/models.py                       +3  −2
comfy/ldm/hydit/poolers.py                      +3  −3
comfy/ldm/modules/diffusionmodules/mmdit.py     +2  −1
comfy/ops.py                                   +23  −2
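
The pattern being factored out is the per-call-site `tensor.to(dtype=input.dtype, device=input.device)` cast; it is replaced everywhere by the new `comfy.ops.cast_to_input` helper added in comfy/ops.py below. A minimal, self-contained sketch of the before/after pattern (the helper body is copied from the comfy/ops.py hunk; the surrounding tensors and names are illustrative only, not from the repository):

import torch

def cast_to_input(weight, input, non_blocking=False):
    # same body as the helper this commit adds to comfy/ops.py
    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

# illustrative tensors, not from the repository
w = torch.randn(4, 4, dtype=torch.float32)       # stored weight
x = torch.randn(2, 4, dtype=torch.float16)       # incoming activation

before = w.to(dtype=x.dtype, device=x.device)    # old call-site pattern, repeated per module
after = cast_to_input(w, x)                      # new shared helper
assert before.dtype == after.dtype == x.dtype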
comfy/ldm/audio/dit.py

@@ -9,6 +9,7 @@ from einops import rearrange
 from torch import nn
 from torch.nn import functional as F
 import math
+import comfy.ops
 
 class FourierFeatures(nn.Module):
     def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
             [out_features // 2, in_features], dtype=dtype, device=device))
 
     def forward(self, input):
-        f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
+        f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
         return torch.cat([f.cos(), f.sin()], dim=-1)
 
 # norms
@@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         beta = self.beta
-        if self.beta is not None:
-            beta = beta.to(dtype=x.dtype, device=x.device)
-        return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
+        if beta is not None:
+            beta = comfy.ops.cast_to_input(beta, x)
+        return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
 
 class GLU(nn.Module):
     def __init__(
@@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
         scale_base = 512,
         interpolation_factor = 1.,
         base = 10000,
-        base_rescale_factor = 1.
+        base_rescale_factor = 1.,
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))
 
-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
+        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
 
         assert interpolation_factor >= 1.
         self.interpolation_factor = interpolation_factor
@@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
         t = t / self.interpolation_factor
 
-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
+        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
         freqs = torch.cat((freqs, freqs), dim = -1)
 
         if self.scale is None:
             return freqs, 1.
 
         power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
-        scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
+        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim = -1)
 
         return freqs, scale
@@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
         self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
 
         if rotary_pos_emb:
-            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
+            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
         else:
             self.rotary_pos_emb = None
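
Besides routing casts through comfy.ops.cast_to_input, this file changes how RotaryEmbedding stores inv_freq: the buffer is now registered empty at the constructor's dtype/device (the old float32 computation is kept as a comment) and is cast to the input tensor at use time. A small, hypothetical sketch of that allocate-empty-then-fill buffer pattern using only PyTorch (RotarySketch and the hand-built state dict are illustrative, not the comfy class; in practice the values would presumably arrive with the model's weights):

import torch

class RotarySketch(torch.nn.Module):
    # hypothetical stand-in for the buffer handling, not the comfy class
    def __init__(self, dim, base=10000, dtype=None, device=None):
        super().__init__()
        # old: inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        # new: allocate at the requested dtype/device, fill the values later
        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

m = RotarySketch(64, dtype=torch.float16)
# illustrative fill, mirroring the commented-out formula above
m.load_state_dict({'inv_freq': 1. / (10000 ** (torch.arange(0, 64, 2).float() / 64))})
assert m.inv_freq.dtype == torch.float16  # the buffer keeps the dtype it was created with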
comfy/ldm/aura/mmdit.py

@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -427,7 +428,7 @@ class MMDiT(nn.Module):
         max_dim = max(h, w)
 
         cur_dim = self.h_max
-        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
+        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
 
         if max_dim > cur_dim:
             pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
@@ -455,7 +456,7 @@ class MMDiT(nn.Module):
         t = timestep
 
         c = self.cond_seq_linear(c_seq)  # B, T_c, D
-        c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
+        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
 
         global_cond = self.t_embedder(t, x.dtype)  # B, D
comfy/ldm/cascade/common.py

@@ -19,14 +19,7 @@
 import torch
 import torch.nn as nn
 from comfy.ldm.modules.attention import optimized_attention
-
-class Linear(torch.nn.Linear):
-    def reset_parameters(self):
-        return None
-
-class Conv2d(torch.nn.Conv2d):
-    def reset_parameters(self):
-        return None
+import comfy.ops
 
 class OptimizedAttention(nn.Module):
     def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
@@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module):
     "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
-        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
-        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
+        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
 
    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
-        return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
+        return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
 
 class ResBlock(nn.Module):
comfy/ldm/hydit/models.py

@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import comfy.ops
 
 from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
@@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module):
         if self.use_style_cond:
             # Here we use a default learned embedder layer for future extension.
-            self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
+            self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
             self.extra_in_dim += hidden_size
 
         # Text embedding for `add`
@@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module):
         b_t5, l_t5, c_t5 = text_states_t5.shape
         text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
 
-        padding = self.text_embedding_padding.to(text_states)
+        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
 
         text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
         text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
comfy/ldm/hydit/poolers.py

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.attention import optimized_attention #TODO
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops
 
 class AttentionPool(nn.Module):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
@@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
         x = x[:,:self.positional_embedding.shape[0] - 1]
         x = x.permute(1, 0, 2)  # NLC -> LNC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
-        x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device)  # (L+1)NC
+        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC
         q = self.q_proj(x[:1])
         k = self.k_proj(x)
comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -8,6 +8,7 @@ import torch.nn as nn
 from .. import attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
+import comfy.ops
 
 def default(x, y):
     if x is not None:
@@ -926,7 +927,7 @@ class MMDiT(nn.Module):
         context = self.context_processor(context)
 
         hw = x.shape[-2:]
-        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+        x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
         c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         if y is not None and self.y_embedder is not None:
             y = self.y_embedder(y)  # (N, D)
comfy/ops.py

@@ -19,14 +19,17 @@
 import torch
 import comfy.model_management
 
+def cast_to_input(weight, input, non_blocking=False):
+    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+
 def cast_bias_weight(s, input):
     bias = None
     non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
     if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias
@@ -168,6 +171,21 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)
 
+    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
+        def reset_parameters(self):
+            self.bias = None
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
@@ -202,3 +220,6 @@ class manual_cast(disable_weight_init):
 
     class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
         comfy_cast_weights = True
+
+    class Embedding(disable_weight_init.Embedding):
+        comfy_cast_weights = True
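
The new Embedding wrapper follows the same forward-time casting pattern as the existing ops classes: under manual_cast, comfy_cast_weights is True, so forward routes through forward_comfy_cast_weights and cast_bias_weight casts the stored weight to match the input, which is what lets the hydit change above construct style_embedder through operations.Embedding. A simplified, standalone illustration of the forward-time cast idea (CastLinear and the toy tensors are hypothetical and not part of comfy.ops; a Linear stand-in is used so the toy dtypes stay simple):

import torch

def cast_to_input(weight, input, non_blocking=False):
    # same body as the helper added to comfy/ops.py above
    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

class CastLinear(torch.nn.Linear):
    # hypothetical layer: parameters keep their storage dtype/device,
    # and every call casts them to match the incoming activations
    def forward(self, x):
        weight = cast_to_input(self.weight, x)
        bias = cast_to_input(self.bias, x) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)

layer = CastLinear(8, 4)                       # weights stored in float32
x = torch.randn(2, 8, dtype=torch.float16)     # activations arrive in float16
y = layer(x)                                   # weight/bias are cast per call; the stored copies are untouched
assert y.dtype == torch.float16 and layer.weight.dtype == torch.float32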