Commit ac796924 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

add score estimation model

parent bd9c9fbf
......@@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
......
......@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet
from .unet_sde_score_estimation import NCSNpp
......@@ -5,6 +5,7 @@ import math
import torch
import torch.nn as nn
try:
import einops
from einops.layers.torch import Rearrange
......@@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__(
self,
training_horizon,
transition_dim,
cond_dim,
predict_epsilon=False,
clip_denoised=True,
dim=32,
dim_mults=(1, 2, 4, 8),
self,
training_horizon,
transition_dim,
cond_dim,
predict_epsilon=False,
clip_denoised=True,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
......@@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
class TemporalValue(nn.Module):
def __init__(
self,
horizon,
transition_dim,
cond_dim,
dim=32,
time_dim=None,
out_dim=1,
dim_mults=(1, 2, 4, 8),
self,
horizon,
transition_dim,
cond_dim,
dim=32,
time_dim=None,
out_dim=1,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment