Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7b030a7d
Unverified
Commit
7b030a7d
authored
Nov 03, 2022
by
Suraj Patil
Committed by
GitHub
Nov 03, 2022
Browse files
handle device for randn in euler step (#1124)
* handle device for randn in euler step * convert device to str
parent
42bb4594
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
2 deletions
+20
-2
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+10
-1
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+10
-1
No files found.
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
7b030a7d
...
@@ -217,7 +217,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -217,7 +217,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample
=
sample
+
derivative
*
dt
prev_sample
=
sample
+
derivative
*
dt
device
=
model_output
.
device
if
torch
.
is_tensor
(
model_output
)
else
"cpu"
device
=
model_output
.
device
if
torch
.
is_tensor
(
model_output
)
else
"cpu"
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
generator
=
generator
).
to
(
device
)
if
str
(
device
)
==
"mps"
:
# randn does not work reproducibly on mps
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
"cpu"
,
generator
=
generator
).
to
(
device
)
else
:
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
device
,
generator
=
generator
).
to
(
device
)
prev_sample
=
prev_sample
+
noise
*
sigma_up
prev_sample
=
prev_sample
+
noise
*
sigma_up
if
not
return_dict
:
if
not
return_dict
:
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
7b030a7d
...
@@ -214,7 +214,16 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -214,7 +214,16 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
gamma
=
min
(
s_churn
/
(
len
(
self
.
sigmas
)
-
1
),
2
**
0.5
-
1
)
if
s_tmin
<=
sigma
<=
s_tmax
else
0.0
gamma
=
min
(
s_churn
/
(
len
(
self
.
sigmas
)
-
1
),
2
**
0.5
-
1
)
if
s_tmin
<=
sigma
<=
s_tmax
else
0.0
device
=
model_output
.
device
if
torch
.
is_tensor
(
model_output
)
else
"cpu"
device
=
model_output
.
device
if
torch
.
is_tensor
(
model_output
)
else
"cpu"
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
generator
=
generator
).
to
(
device
)
if
str
(
device
)
==
"mps"
:
# randn does not work reproducibly on mps
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
"cpu"
,
generator
=
generator
).
to
(
device
)
else
:
noise
=
torch
.
randn
(
model_output
.
shape
,
dtype
=
model_output
.
dtype
,
device
=
device
,
generator
=
generator
).
to
(
device
)
eps
=
noise
*
s_noise
eps
=
noise
*
s_noise
sigma_hat
=
sigma
*
(
gamma
+
1
)
sigma_hat
=
sigma
*
(
gamma
+
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment