Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
08e62fe0
Unverified
Commit
08e62fe0
authored
Jan 16, 2025
by
hlky
Committed by
GitHub
Jan 16, 2025
Browse files
Scheduling fixes on MPS (#10549)
* use np.int32 in scheduling * test_add_noise_device * -np.int32, fixes
parent
9e1b8a00
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
5 deletions
+5
-5
src/diffusers/schedulers/scheduling_heun_discrete.py
src/diffusers/schedulers/scheduling_heun_discrete.py
+1
-1
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+1
-1
tests/schedulers/test_scheduler_lcm.py
tests/schedulers/test_scheduler_lcm.py
+1
-1
tests/schedulers/test_schedulers.py
tests/schedulers/test_schedulers.py
+2
-2
No files found.
src/diffusers/schedulers/scheduling_heun_discrete.py
View file @
08e62fe0
...
...
@@ -342,7 +342,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps
=
torch
.
from_numpy
(
timesteps
)
timesteps
=
torch
.
cat
([
timesteps
[:
1
],
timesteps
[
1
:].
repeat_interleave
(
2
)])
self
.
timesteps
=
timesteps
.
to
(
device
=
device
)
self
.
timesteps
=
timesteps
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
# empty dt and derivative
self
.
prev_derivative
=
None
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
08e62fe0
...
...
@@ -311,7 +311,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas
=
np
.
concatenate
([
sigmas
,
[
0.0
]]).
astype
(
np
.
float32
)
self
.
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
,
dtype
=
torch
.
float32
)
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
...
...
tests/schedulers/test_scheduler_lcm.py
View file @
08e62fe0
...
...
@@ -99,7 +99,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
scaled_sample
=
scheduler
.
scale_model_input
(
sample
,
0.0
)
self
.
assertEqual
(
sample
.
shape
,
scaled_sample
.
shape
)
noise
=
torch
.
randn
_like
(
scaled_sample
).
to
(
torch_device
)
noise
=
torch
.
randn
(
scaled_sample
.
shape
).
to
(
torch_device
)
t
=
scheduler
.
timesteps
[
5
][
None
]
noised
=
scheduler
.
add_noise
(
scaled_sample
,
noise
,
t
)
self
.
assertEqual
(
noised
.
shape
,
scaled_sample
.
shape
)
...
...
tests/schedulers/test_schedulers.py
View file @
08e62fe0
...
...
@@ -361,7 +361,7 @@ class SchedulerCommonTest(unittest.TestCase):
if
isinstance
(
t
,
torch
.
Tensor
):
num_dims
=
len
(
sample
.
shape
)
# pad t with 1s to match num_dims
t
=
t
.
reshape
(
-
1
,
*
(
1
,)
*
(
num_dims
-
1
)).
to
(
sample
.
device
).
to
(
sample
.
dtype
)
t
=
t
.
reshape
(
-
1
,
*
(
1
,)
*
(
num_dims
-
1
)).
to
(
sample
.
device
,
dtype
=
sample
.
dtype
)
return
sample
*
t
/
(
t
+
1
)
...
...
@@ -722,7 +722,7 @@ class SchedulerCommonTest(unittest.TestCase):
scaled_sample
=
scheduler
.
scale_model_input
(
sample
,
0.0
)
self
.
assertEqual
(
sample
.
shape
,
scaled_sample
.
shape
)
noise
=
torch
.
randn
_like
(
scaled_sample
).
to
(
torch_device
)
noise
=
torch
.
randn
(
scaled_sample
.
shape
).
to
(
torch_device
)
t
=
scheduler
.
timesteps
[
5
][
None
]
noised
=
scheduler
.
add_noise
(
scaled_sample
,
noise
,
t
)
self
.
assertEqual
(
noised
.
shape
,
scaled_sample
.
shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment