chenpangpang / ComfyUI · Commit 73ce1780 (unverified)

Remove redundancy in mmdit.py (#3685)

Authored by Dango233 on Jun 11, 2024; committed by GitHub on Jun 11, 2024.
Parent: 4134564d
Showing 1 changed file with 0 additions and 61 deletions.

comfy/ldm/modules/diffusionmodules/mmdit.py (+0, -61)
@@ -835,72 +835,11 @@ class MMDiT(nn.Module):
         )
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

         # self.initialize_weights()

         if compile_core:
             assert False
             self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)

-    def initialize_weights(self):
-        # TODO: Init context_embedder?
-        # Initialize transformer layers:
-        def _basic_init(module):
-            if isinstance(module, nn.Linear):
-                torch.nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-
-        self.apply(_basic_init)
-
-        # Initialize (and freeze) pos_embed by sin-cos embedding
-        if self.pos_embed is not None:
-            pos_embed_grid_size = (
-                int(self.x_embedder.num_patches ** 0.5)
-                if self.pos_embed_max_size is None
-                else self.pos_embed_max_size
-            )
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.x_embedder.num_patches ** 0.5),
-                pos_embed_grid_size,
-                scaling_factor=self.pos_embed_scaling_factor,
-                offset=self.pos_embed_offset,
-            )
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.pos_embed.shape[-2] ** 0.5),
-                scaling_factor=self.pos_embed_scaling_factor,
-            )
-            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
-        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
-        w = self.x_embedder.proj.weight.data
-        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-        nn.init.constant_(self.x_embedder.proj.bias, 0)
-
-        if hasattr(self, "y_embedder"):
-            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
-            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
-
-        # Initialize timestep embedding MLP:
-        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
-        # Zero-out adaLN modulation layers in DiT blocks:
-        for block in self.joint_blocks:
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
-
-        # Zero-out output layers:
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
-        nn.init.constant_(self.final_layer.linear.weight, 0)
-        nn.init.constant_(self.final_layer.linear.bias, 0)
-
     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
         h, w = hw
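The removed initialize_weights() appears to have been unused: its only visible call site, self.initialize_weights(), is commented out in the surrounding context, which is presumably the redundancy the commit title refers to, so removing the method does not change runtime behavior. For reference, below is a minimal, self-contained sketch of the DiT-style initialization pattern the deleted method implemented (Xavier-uniform for Linear layers, zero-init for the adaLN modulation and output projections). The TinyAdaLNBlock module and its attribute names are illustrative stand-ins, not ComfyUI's actual classes.

    import torch.nn as nn

    class TinyAdaLNBlock(nn.Module):
        """Illustrative block with an adaLN_modulation head, standing in for a DiT block."""
        def __init__(self, hidden_size: int):
            super().__init__()
            self.attn_qkv = nn.Linear(hidden_size, 3 * hidden_size)
            # The last Linear in adaLN_modulation is the layer the removed code zero-initialized.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))

    def initialize_weights(model: nn.Module) -> None:
        # Xavier-uniform for every Linear with zero bias (mirrors _basic_init in the removed code).
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        model.apply(_basic_init)

        # Zero out adaLN modulation layers so the conditioning path starts with no effect.
        for m in model.modules():
            if hasattr(m, "adaLN_modulation"):
                nn.init.constant_(m.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(m.adaLN_modulation[-1].bias, 0)

    block = TinyAdaLNBlock(64)
    initialize_weights(block)
    print(block.adaLN_modulation[-1].weight.abs().sum())  # prints a zero tensor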