Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
10b43cee
Commit
10b43cee
authored
Jul 24, 2024
by
comfyanonymous
Browse files
Remove duplicate code.
parent
0a4c49c5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
27 deletions
+2
-27
comfy/ldm/modules/diffusionmodules/mmdit.py
comfy/ldm/modules/diffusionmodules/mmdit.py
+2
-27
No files found.
comfy/ldm/modules/diffusionmodules/mmdit.py
View file @
10b43cee
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
..
import
attention
from
..
import
attention
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
from
.util
import
timestep_embedding
def
default
(
x
,
y
):
def
default
(
x
,
y
):
if
x
is
not
None
:
if
x
is
not
None
:
...
@@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
...
@@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
)
)
self
.
frequency_embedding_size
=
frequency_embedding_size
self
.
frequency_embedding_size
=
frequency_embedding_size
@
staticmethod
def
timestep_embedding
(
t
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
,
device
=
t
.
device
)
/
half
)
args
=
t
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
(
[
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
if
torch
.
is_floating_point
(
t
):
embedding
=
embedding
.
to
(
dtype
=
t
.
dtype
)
return
embedding
def forward(self, t, dtype, **kwargs):
    """Map raw timesteps `t` to embeddings via sinusoidal features and the MLP.

    NOTE(review): `timestep_embedding` here is the module-level helper imported
    from `.util`; presumably it matches the removed staticmethod — confirm.
    """
    freq_emb = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
    return self.mlp(freq_emb)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment