Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8c1f5197
Commit
8c1f5197
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
make clip name shorter
parent
dcb23b2d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
12 deletions
+12
-12
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+4
-4
tests/test_scheduler.py
tests/test_scheduler.py
+6
-6
No files found.
src/diffusers/schedulers/scheduling_ddim.py
View file @
8c1f5197
...
...
@@ -28,7 +28,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule
=
"linear"
,
trained_betas
=
None
,
timestep_values
=
None
,
clip_
predicted_
sample
=
True
,
clip_sample
=
True
,
tensor_format
=
"np"
,
):
super
().
__init__
()
...
...
@@ -40,7 +40,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
self
.
timesteps
=
int
(
timesteps
)
self
.
timestep_values
=
timestep_values
# save the fixed timestep values for BDDM
self
.
clip_sample
=
clip_
predicted_
sample
self
.
clip_sample
=
clip_sample
if
beta_schedule
==
"linear"
:
self
.
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
8c1f5197
...
...
@@ -29,7 +29,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas
=
None
,
timestep_values
=
None
,
variance_type
=
"fixed_small"
,
clip_
predicted_
sample
=
True
,
clip_sample
=
True
,
tensor_format
=
"np"
,
):
super
().
__init__
()
...
...
@@ -41,11 +41,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas
=
trained_betas
,
timestep_values
=
timestep_values
,
variance_type
=
variance_type
,
clip_
predicted_
sample
=
clip_
predicted_
sample
,
clip_sample
=
clip_sample
,
)
self
.
timesteps
=
int
(
timesteps
)
self
.
timestep_values
=
timestep_values
# save the fixed timestep values for BDDM
self
.
clip_sample
=
clip_
predicted_
sample
self
.
clip_sample
=
clip_sample
self
.
variance_type
=
variance_type
if
trained_betas
is
not
None
:
...
...
@@ -124,7 +124,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample
=
(
sample
-
beta_prod_t
**
(
0.5
)
*
residual
)
/
alpha_prod_t
**
(
0.5
)
# 3. Clip "predicted x_0"
if
self
.
clip_
predicted_
sample
:
if
self
.
clip_sample
:
pred_original_sample
=
self
.
clip
(
pred_original_sample
,
-
1
,
1
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
...
...
tests/test_scheduler.py
View file @
8c1f5197
...
...
@@ -172,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"variance_type"
:
"fixed_small"
,
"clip_
predicted_
sample"
:
True
,
"clip_sample"
:
True
,
}
config
.
update
(
**
kwargs
)
...
...
@@ -195,8 +195,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
self
.
check_over_configs
(
variance_type
=
variance
)
def
test_clip_image
(
self
):
for
clip_
predicted_
sample
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_
predicted_
sample
=
clip_
predicted_
sample
)
for
clip_sample
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_sample
=
clip_sample
)
def
test_time_indices
(
self
):
for
t
in
[
0
,
500
,
999
]:
...
...
@@ -251,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"clip_
predicted_
sample"
:
True
,
"clip_sample"
:
True
,
}
config
.
update
(
**
kwargs
)
...
...
@@ -270,8 +270,8 @@ class DDIMSchedulerTest(SchedulerCommonTest):
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_clip_image
(
self
):
for
clip_
predicted_
sample
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_
predicted_
sample
=
clip_
predicted_
sample
)
for
clip_sample
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_sample
=
clip_sample
)
def
test_time_indices
(
self
):
for
t
in
[
1
,
10
,
49
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment