"src/vscode:/vscode.git/clone" did not exist on "33d7e89c42e0fe7b4a277d7a5bae12ba14828dd8"
Commit 5b1af9ab authored by Patrick von Platen's avatar Patrick von Platen
Browse files

correct naming in glide

parent 0a1d4c58
...@@ -29,7 +29,7 @@ from torchvision import transforms, utils ...@@ -29,7 +29,7 @@ from torchvision import transforms, utils
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixinMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
...@@ -175,7 +175,7 @@ class AttnBlock(nn.Module): ...@@ -175,7 +175,7 @@ class AttnBlock(nn.Module):
return x + h_ return x + h_
class UNetModel(ModelMixin, ConfigMixin): class UNetModel(ModelMixin, ConfigMixinMixin):
def __init__( def __init__(
self, self,
ch=128, ch=128,
......
...@@ -5,8 +5,8 @@ import torch ...@@ -5,8 +5,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..configuration_utils import Config from ..configuration_utils import ConfigMixin
from ..modeling_utils import PreTrainedModel from ..modeling_utils import ModelMixin
def convert_module_to_f16(l): def convert_module_to_f16(l):
...@@ -388,7 +388,7 @@ class QKVAttention(nn.Module): ...@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
class UNetGLIDEModel(PreTrainedModel, Config): class UNetGLIDEModel(ModelMixin, ConfigMixin):
""" """
The full UNet model with attention and timestep embedding. The full UNet model with attention and timestep embedding.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment