Commit ac796924 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

add score estimation model

parent bd9c9fbf
...@@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode ...@@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4" __version__ = "0.0.4"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models.unet import UNetModel from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
......
...@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide ...@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
from .unet_grad_tts import UNetGradTTSModel from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet from .unet_rl import TemporalUNet
from .unet_sde_score_estimation import NCSNpp
...@@ -5,6 +5,7 @@ import math ...@@ -5,6 +5,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
try: try:
import einops import einops
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
...@@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module): ...@@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__( def __init__(
self, self,
training_horizon, training_horizon,
transition_dim, transition_dim,
cond_dim, cond_dim,
predict_epsilon=False, predict_epsilon=False,
clip_denoised=True, clip_denoised=True,
dim=32, dim=32,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
): ):
super().__init__() super().__init__()
...@@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
class TemporalValue(nn.Module): class TemporalValue(nn.Module):
def __init__( def __init__(
self, self,
horizon, horizon,
transition_dim, transition_dim,
cond_dim, cond_dim,
dim=32, dim=32,
time_dim=None, time_dim=None,
out_dim=1, out_dim=1,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
): ):
super().__init__() super().__init__()
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment