"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3dc97bd1482fb099aa41a15aae55ff45c8f2b042"
Commit 49718b47 authored by Nathan Lambert's avatar Nathan Lambert
Browse files

add imports for RL UNet

parent 9c96682a
...@@ -9,6 +9,7 @@ from .models.unet import UNetModel ...@@ -9,6 +9,7 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
......
...@@ -20,3 +20,4 @@ from .unet import UNetModel ...@@ -20,3 +20,4 @@ from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .unet_grad_tts import UNetGradTTSModel from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet
\ No newline at end of file
...@@ -6,6 +6,10 @@ import einops ...@@ -6,6 +6,10 @@ import einops
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
import math import math
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
...@@ -85,7 +89,7 @@ class ResidualTemporalBlock(nn.Module): ...@@ -85,7 +89,7 @@ class ResidualTemporalBlock(nn.Module):
out = self.blocks[1](out) out = self.blocks[1](out)
return out + self.residual_conv(x) return out + self.residual_conv(x)
class TemporalUnet(nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
def __init__( def __init__(
self, self,
...@@ -99,7 +103,7 @@ class TemporalUnet(nn.Module): ...@@ -99,7 +103,7 @@ class TemporalUnet(nn.Module):
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}') # print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim = dim time_dim = dim
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment