Unverified Commit b53924c7 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

Merge pull request #6 from huggingface/add-ldm

add unet ldm in init
parents ee71a3b6 4d53a521
@@ -7,5 +7,6 @@ __version__ = "0.0.1"
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import UNetGLIDEModel
+from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
@@ -18,3 +18,4 @@
 from .unet import UNetModel
 from .unet_glide import UNetGLIDEModel
+from .unet_ldm import UNetLDMModel
@@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
-        self.dtype = torch.float16 if use_fp16 else torch.float32
+        self.dtype_ = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             assert y.shape == (x.shape[0],)
             emb = emb + self.label_emb(y)
-        h = x.type(self.dtype)
+        h = x.type(self.dtype_)
         for module in self.input_blocks:
             h = module(h, emb, context)
             hs.append(h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment