Commit 25853d0b, authored Jul 30, 2024 by comfyanonymous
Parent: 79040635

Use common function for casting weights to input.
Showing 7 changed files with 51 additions and 31 deletions.
comfy/ldm/audio/dit.py                        +13 -10
comfy/ldm/aura/mmdit.py                        +3  -2
comfy/ldm/cascade/common.py                    +4 -11
comfy/ldm/hydit/models.py                      +3  -2
comfy/ldm/hydit/poolers.py                     +3  -3
comfy/ldm/modules/diffusionmodules/mmdit.py    +2  -1
comfy/ops.py                                  +23  -2
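
The helper this commit introduces, comfy.ops.cast_to_input, simply moves a weight tensor onto the calling input's device and dtype (its definition appears in the comfy/ops.py diff at the bottom of this page). Below is a minimal sketch of the call-site pattern the other six files switch to; the tensors are illustrative and not taken from the repository:

import torch

def cast_to_input(weight, input, non_blocking=False):
    # Same definition as the one added to comfy/ops.py in this commit.
    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

# Illustrative tensors: an fp32 parameter and an fp16 activation.
weight = torch.randn(8, 8, dtype=torch.float32)
x = torch.randn(4, 8, dtype=torch.float16)

manual = weight.to(dtype=x.dtype, device=x.device)  # old, hand-written pattern
shared = cast_to_input(weight, x)                   # new, via the shared helper

assert torch.equal(manual, shared)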
comfy/ldm/audio/dit.py

@@ -9,6 +9,7 @@ from einops import rearrange
 from torch import nn
 from torch.nn import functional as F
 import math
+import comfy.ops

 class FourierFeatures(nn.Module):
     def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
             [out_features // 2, in_features], dtype=dtype, device=device))

     def forward(self, input):
-        f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
+        f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
         return torch.cat([f.cos(), f.sin()], dim=-1)

 # norms
@@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         beta = self.beta
-        if self.beta is not None:
-            beta = beta.to(dtype=x.dtype, device=x.device)
-        return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
+        if beta is not None:
+            beta = comfy.ops.cast_to_input(beta, x)
+        return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)

 class GLU(nn.Module):
     def __init__(
@@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
         scale_base = 512,
         interpolation_factor = 1.,
         base = 10000,
-        base_rescale_factor = 1.
+        base_rescale_factor = 1.,
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))

-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
+        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

         assert interpolation_factor >= 1.
         self.interpolation_factor = interpolation_factor
@@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
         t = t / self.interpolation_factor

-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
+        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
         freqs = torch.cat((freqs, freqs), dim = -1)

         if self.scale is None:
             return freqs, 1.

         power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
-        scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
+        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim = -1)

         return freqs, scale
@@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
         self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

         if rotary_pos_emb:
-            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
+            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
         else:
             self.rotary_pos_emb = None
comfy/ldm/aura/mmdit.py

@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops

 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -427,7 +428,7 @@ class MMDiT(nn.Module):
         max_dim = max(h, w)

         cur_dim = self.h_max
-        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
+        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)

         if max_dim > cur_dim:
             pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
@@ -455,7 +456,7 @@ class MMDiT(nn.Module):
         t = timestep

         c = self.cond_seq_linear(c_seq)  # B, T_c, D
-        c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
+        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)

         global_cond = self.t_embedder(t, x.dtype)  # B, D
comfy/ldm/cascade/common.py

@@ -19,14 +19,7 @@
 import torch
 import torch.nn as nn
 from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops
-
-class Linear(torch.nn.Linear):
-    def reset_parameters(self):
-        return None
-
-class Conv2d(torch.nn.Conv2d):
-    def reset_parameters(self):
-        return None

 class OptimizedAttention(nn.Module):
     def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
@@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module):
     "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
     def __init__(self, dim, dtype=None, device=None):
         super().__init__()
-        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
-        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
+        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))

     def forward(self, x):
         Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
-        return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
+        return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x

 class ResBlock(nn.Module):
comfy/ldm/hydit/models.py

@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import comfy.ops

 from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
@@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module):
         if self.use_style_cond:
             # Here we use a default learned embedder layer for future extension.
-            self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
+            self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
             self.extra_in_dim += hidden_size

         # Text embedding for `add`
@@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module):
         b_t5, l_t5, c_t5 = text_states_t5.shape
         text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

-        padding = self.text_embedding_padding.to(text_states)
+        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

         text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
         text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
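
The nn.Embedding to operations.Embedding switch above relies on the new Embedding op added to comfy/ops.py in the last diff on this page. As a rough illustration of the pattern, ComfyUI modules receive an `operations` namespace and build their layers from it, so cast-aware layer classes can be swapped in. The module below is hypothetical and the stand-in ops class is heavily stripped down:

import torch
import torch.nn as nn

class ops_sketch:
    # Stand-in for comfy.ops.disable_weight_init: layer classes whose
    # reset_parameters() is a no-op, so weights are only filled in when a
    # checkpoint is loaded.
    class Embedding(nn.Embedding):
        def reset_parameters(self):
            return None

class StyleCond(nn.Module):
    # Hypothetical module mirroring the HunYuanDiT change: the embedding class
    # comes from the injected `operations` namespace instead of torch.nn.
    def __init__(self, hidden_size, dtype=None, device=None, operations=ops_sketch):
        super().__init__()
        self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)

m = StyleCond(16)
print(m.style_embedder(torch.zeros(2, dtype=torch.long)).shape)  # torch.Size([2, 16])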
comfy/ldm/hydit/poolers.py

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.attention import optimized_attention #TODO
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops

 class AttentionPool(nn.Module):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
@@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
         x = x[:,:self.positional_embedding.shape[0] - 1]
         x = x.permute(1, 0, 2)  # NLC -> LNC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
-        x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device)  # (L+1)NC
+        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC
         q = self.q_proj(x[:1])
         k = self.k_proj(x)
comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -8,6 +8,7 @@ import torch.nn as nn
 from .. import attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
+import comfy.ops

 def default(x, y):
     if x is not None:
@@ -926,7 +927,7 @@ class MMDiT(nn.Module):
         context = self.context_processor(context)

         hw = x.shape[-2:]
-        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+        x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
         c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         if y is not None and self.y_embedder is not None:
             y = self.y_embedder(y)  # (N, D)
comfy/ops.py

@@ -19,14 +19,17 @@
 import torch
 import comfy.model_management

+def cast_to_input(weight, input, non_blocking=False):
+    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+
 def cast_bias_weight(s, input):
     bias = None
     non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
     if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias
@@ -168,6 +171,21 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

+    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
+        def reset_parameters(self):
+            self.bias = None
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
@@ -202,3 +220,6 @@ class manual_cast(disable_weight_init):
     class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
         comfy_cast_weights = True
+
+    class Embedding(disable_weight_init.Embedding):
+        comfy_cast_weights = True
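
For reference, the two helpers above compose like this at runtime. The following is a minimal, self-contained sketch: the Linear layer and tensor shapes are illustrative, and the non_blocking and weight_function/bias_function handling of the real cast_bias_weight is omitted.

import torch

# Simplified copies of the two helpers from comfy/ops.py after this commit
# (non_blocking and the weight_function/bias_function hooks are left out).
def cast_to_input(weight, input, non_blocking=False):
    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

def cast_bias_weight(s, input):
    bias = None
    if s.bias is not None:
        bias = cast_to_input(s.bias, input)
    weight = cast_to_input(s.weight, input)
    return weight, bias

# Illustrative use: a Linear layer kept in fp32 while the activation is fp16.
layer = torch.nn.Linear(8, 4)                 # fp32 parameters
x = torch.randn(2, 8, dtype=torch.float16)
weight, bias = cast_bias_weight(layer, x)
print(weight.dtype, bias.dtype)               # torch.float16 torch.float16
# A comfy_cast_weights op would then call its functional form with these tensors,
# e.g. torch.nn.functional.linear(x, weight, bias) for a Linear.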