Commit cc45831e authored by patil-suraj's avatar patil-suraj
Browse files

add GradTTSScheduler

parent 2d8d82f9
...@@ -11,5 +11,5 @@ from .models.unet_ldm import UNetLDMModel ...@@ -11,5 +11,5 @@ from .models.unet_ldm import UNetLDMModel
from .models.unet_grad_tts import UNetGradTTSModel from .models.unet_grad_tts import UNetGradTTSModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler, GradTTSScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
...@@ -20,4 +20,5 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler ...@@ -20,4 +20,5 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .scheduling_ddim import DDIMScheduler from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm import DDPMScheduler
from .scheduling_pndm import PNDMScheduler from .scheduling_pndm import PNDMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class GradTTSScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
tensor_format="np",
):
super().__init__()
self.register(
timesteps=timesteps,
beta_start=beta_start,
beta_end=beta_end,
)
self.timesteps = int(timesteps)
self.set_format(tensor_format=tensor_format)
def sample_noise(self, timestep):
noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
return noise
def step(self, xt, residual, mu, h, timestep):
noise_t = self.sample_noise(timestep)
dxt = 0.5 * (mu - xt - residual)
dxt = dxt * noise_t * h
xt = xt - dxt
return xt
def __len__(self):
return self.timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment