Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
7b030a7d
Unverified
Commit
7b030a7d
authored
Nov 03, 2022
by
Suraj Patil
Committed by
GitHub
Nov 03, 2022
Browse files
handle device for randn in euler step (#1124)
* handle device for randn in euler step * convert device to str
parent
42bb4594
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
2 deletions
+20
-2
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+10
-1
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+10
-1
No files found.
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
7b030a7d
...
...
@@ -217,7 +217,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample
=
sample
+
derivative
*
dt
device
=
model_output
.
device
if
torch
.
is_tensor
(
model_output
)
else
"cpu"
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
generator
=
generator
).
to
(
device
)
if
str
(
device
)
==
"mps"
:
# randn does not work reproducibly on mps
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
"cpu"
,
generator
=
generator
).
to
(
device
)
else
:
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
device
,
generator
=
generator
).
to
(
device
)
prev_sample
=
prev_sample
+
noise
*
sigma_up
if
not
return_dict
:
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
7b030a7d
...
...
@@ -214,7 +214,16 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
gamma
=
min
(
s_churn
/
(
len
(
self
.
sigmas
)
-
1
),
2
**
0.5
-
1
)
if
s_tmin
<=
sigma
<=
s_tmax
else
0.0
device
=
model_output
.
device
if
torch
.
is_tensor
(
model_output
)
else
"cpu"
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
generator
=
generator
).
to
(
device
)
if
str
(
device
)
==
"mps"
:
# randn does not work reproducibly on mps
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
"cpu"
,
generator
=
generator
).
to
(
device
)
else
:
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
device
,
generator
=
generator
).
to
(
device
)
eps
=
noise
*
s_noise
sigma_hat
=
sigma
*
(
gamma
+
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment