Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
cc45831e
Commit
cc45831e
authored
Jun 16, 2022
by
patil-suraj
Browse files
add GradTTSScheduler
parent
2d8d82f9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
1 deletion
+54
-1
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-0
src/diffusers/schedulers/scheduling_grad_tts.py
src/diffusers/schedulers/scheduling_grad_tts.py
+52
-0
No files found.
src/diffusers/__init__.py
View file @
cc45831e
...
@@ -11,5 +11,5 @@ from .models.unet_ldm import UNetLDMModel
...
@@ -11,5 +11,5 @@ from .models.unet_ldm import UNetLDMModel
from
.models.unet_grad_tts
import
UNetGradTTSModel
from
.models.unet_grad_tts
import
UNetGradTTSModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIM
,
DDPM
,
GLIDE
,
LatentDiffusion
,
PNDM
,
BDDM
from
.pipelines
import
DDIM
,
DDPM
,
GLIDE
,
LatentDiffusion
,
PNDM
,
BDDM
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
SchedulerMixin
,
PNDMScheduler
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
SchedulerMixin
,
PNDMScheduler
,
GradTTSScheduler
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
src/diffusers/schedulers/__init__.py
View file @
cc45831e
...
@@ -20,4 +20,5 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
...
@@ -20,4 +20,5 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from
.scheduling_ddim
import
DDIMScheduler
from
.scheduling_ddim
import
DDIMScheduler
from
.scheduling_ddpm
import
DDPMScheduler
from
.scheduling_ddpm
import
DDPMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
src/diffusers/schedulers/scheduling_grad_tts.py
0 → 100644
View file @
cc45831e
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
class
GradTTSScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
timesteps
=
1000
,
beta_start
=
0.0001
,
beta_end
=
0.02
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register
(
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
)
self
.
timesteps
=
int
(
timesteps
)
self
.
set_format
(
tensor_format
=
tensor_format
)
def
sample_noise
(
self
,
timestep
):
noise
=
self
.
beta_start
+
(
self
.
beta_end
-
self
.
beta_start
)
*
timestep
return
noise
def
step
(
self
,
xt
,
residual
,
mu
,
h
,
timestep
):
noise_t
=
self
.
sample_noise
(
timestep
)
dxt
=
0.5
*
(
mu
-
xt
-
residual
)
dxt
=
dxt
*
noise_t
*
h
xt
=
xt
-
dxt
return
xt
def
__len__
(
self
):
return
self
.
timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment