"docs/source/api/python/dgl.DGLGraph.rst" did not exist on "3ee7e11268eb96c44cad3e57d4c8aa870ed2fa02"
Commit 07ffe73f authored by anton-l

Style

parent bb98a5b7
import torch
from torch import nn

-from transformers import CLIPTextConfig, GPT2Tokenizer
-from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
+from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
 from modeling_glide import GLIDE
+from transformers import CLIPTextConfig, GPT2Tokenizer

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")

@@ -22,7 +23,7 @@ config = CLIPTextConfig(
)

model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
-#tokenizer.save_pretrained("./glide-base")
+# tokenizer.save_pretrained("./glide-base")

hf_encoder = model.text_model

@@ -51,11 +52,11 @@ for layer_idx in range(config.num_hidden_layers):
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

-#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
-#with torch.no_grad():
+# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
+# with torch.no_grad():
#     outputs = model(**inputs)
-#model.save_pretrained("./glide-base")
+# model.save_pretrained("./glide-base")

### Convert the UNet

@@ -80,4 +81,4 @@ scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squar
glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
glide.save_pretrained("./glide-base")
\ No newline at end of file
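For a quick check of the converted text encoder, the commented-out verification lines in the script above can be run as-is; the expected output shape below is an assumption based on the 128-token padding and the hidden size set in the CLIPTextConfig.

# Sanity check for the converted text encoder (mirrors the commented-out lines above).
inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# Expected (assumption): last_hidden_state of shape (2, 128, config.hidden_size)
print(outputs.last_hidden_state.shape)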
@@ -14,12 +14,12 @@
# limitations under the License.

-from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
-from transformers import GPT2Tokenizer
+import numpy as np
+import torch
 import tqdm
-import torch
-import numpy as np
+from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
+from transformers import GPT2Tokenizer

def _extract_into_tensor(arr, timesteps, broadcast_shape):

@@ -40,14 +40,16 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        unet: UNetGLIDEModel,
        noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
-        tokenizer: GPT2Tokenizer
+        tokenizer: GPT2Tokenizer,
    ):
        super().__init__()
-        self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
+        self.register_modules(
+            unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
+        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """

@@ -129,7 +131,9 @@ class GLIDE(DiffusionPipeline):
        self.text_encoder.to(torch_device)

        # 1. Sample gaussian noise
-        image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator)
+        image = self.noise_scheduler.sample_noise(
+            (1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+        )

        # 2. Encode tokens
        # an empty input is needed to guide the model away from (

@@ -141,9 +145,7 @@ class GLIDE(DiffusionPipeline):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
            noise = self.noise_scheduler.sample_noise(image.shape)
-            nonzero_mask = (
-                (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
-            )  # no noise when t == 0
+            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        return image
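For reference, the update inside the denoising loop above is the standard ancestral sampling step of a DDPM-style reverse process, with the nonzero mask suppressing noise at the final step:

x_{t-1} = \mu_\theta(x_t, t) + \mathbb{1}[t \neq 0] \cdot \exp\!\big(\tfrac{1}{2}\log\sigma_\theta^2(x_t, t)\big) \cdot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

where \mu_\theta and \log\sigma_\theta^2 are the mean and log-variance returned by p_mean_variance.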
import torch

from modeling_glide import GLIDE

generator = torch.Generator()
generator = generator.manual_seed(0)
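The sampling script is cut off in this view after seeding the generator. A hypothetical continuation (the load path and the pipeline call signature below are assumptions, not shown in this commit) might look like:

# Hypothetical continuation -- GLIDE's call signature is not shown in this diff,
# so treat these lines purely as an illustrative sketch.
glide = GLIDE.from_pretrained("./glide-base")  # assumes DiffusionPipeline exposes from_pretrained
image = glide(generator=generator)             # hypothetical call; arguments are assumptions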
@@ -5,10 +5,10 @@
__version__ = "0.0.1"

from .modeling_utils import ModelMixin
+from .models.clip_text_transformer import CLIPTextModel
 from .models.unet import UNetModel
 from .models.unet_glide import UNetGLIDEModel
 from .models.unet_ldm import UNetLDMModel
-from .models.clip_text_transformer import CLIPTextModel
 from .pipeline_utils import DiffusionPipeline
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
@@ -89,7 +89,6 @@ class ConfigMixin:
        self.to_json_file(output_config_file)
        logger.info(f"ConfigMixinuration saved in {output_config_file}")

    @classmethod
    def get_config_dict(

@@ -183,7 +182,7 @@ class ConfigMixin:
            logger.info(f"loading configuration file {config_file}")
        else:
            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")

        return config_dict

    @classmethod

@@ -199,9 +198,8 @@ class ConfigMixin:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        unused_kwargs = config_dict.update(kwargs)
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.warn(

@@ -212,9 +210,7 @@ class ConfigMixin:
    @classmethod
    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-        config_dict = cls.get_config_dict(
-            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
-        )
+        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
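Note that dict.update returns None, so the unused_kwargs assigned above is always None; the usable split of known versus unknown keys comes from extract_init_dict. As a hedged illustration of the from_config path shown here: any class mixing in ConfigMixin (for example the ClassifierFreeGuidanceScheduler later in this commit) could be re-instantiated from a saved config; the directory layout below is an assumption, not something shown in this diff.

# Illustrative sketch only -- the saved sub-directory name is an assumption.
scheduler = ClassifierFreeGuidanceScheduler.from_config("./glide-base/noise_scheduler")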
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from .clip_text_transformer import CLIPTextModel
 from .unet import UNetModel
 from .unet_glide import UNetGLIDEModel
 from .unet_ldm import UNetLDMModel
-from .clip_text_transformer import CLIPTextModel
@@ -14,14 +14,15 @@
# limitations under the License.
""" PyTorch CLIP model."""

-from dataclasses import dataclass
 import math
+from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn
+from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel

@@ -32,7 +33,7 @@ from transformers.utils import (
    logging,
    replace_return_docstrings,
)
-from transformers import CLIPModel, CLIPConfig, CLIPVisionConfig, CLIPTextConfig

logger = logging.get_logger(__name__)

@@ -153,11 +154,11 @@ class CLIPTextEmbeddings(nn.Module):
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

@@ -193,16 +194,15 @@ class CLIPAttention(nn.Module):
        )
        self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))

-        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim*3)
+        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

@@ -212,9 +212,7 @@ class CLIPAttention(nn.Module):
        qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
        query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)

-        attn_weights = torch.einsum(
-            "bthc,bshc->bhts", query_states * self.scale, key_states * self.scale
-        )
+        attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)

        wdtype = attn_weights.dtype
        attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)
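A note on the scaling in the attention hunk above: since self.scale = 1 / sqrt(sqrt(head_dim)) = d_h^{-1/4}, multiplying both the query and the key by it before the einsum reproduces the usual scaled dot-product attention logits,

(q \cdot d_h^{-1/4})\,(k \cdot d_h^{-1/4})^\top = \frac{q k^\top}{\sqrt{d_h}},

computed per batch, head, target position, and source position by the "bthc,bshc->bhts" contraction.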
@@ -252,11 +250,11 @@ class CLIPEncoderLayer(nn.Module):
        self.layer_norm2 = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:

@@ -313,19 +311,19 @@ class CLIPPreTrainedModel(PreTrainedModel):
            module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPVisionEmbeddings):
            factor = self.config.initializer_factor
-            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim ** -0.5 * factor)
+            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, CLIPAttention):
            factor = self.config.initializer_factor
-            in_proj_std = (module.embed_dim ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
-            out_proj_std = (module.embed_dim ** -0.5) * factor
+            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.qkv_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, CLIPMLP):
            factor = self.config.initializer_factor
            in_proj_std = (
-                (module.config.hidden_size ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+                (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            )
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)

@@ -333,11 +331,11 @@ class CLIPPreTrainedModel(PreTrainedModel):
        elif isinstance(module, CLIPModel):
            nn.init.normal_(
                module.text_projection.weight,
-                std=module.text_embed_dim ** -0.5 * self.config.initializer_factor,
+                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
            )
            nn.init.normal_(
                module.visual_projection.weight,
-                std=module.vision_embed_dim ** -0.5 * self.config.initializer_factor,
+                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
            )

        if isinstance(module, nn.LayerNorm):

@@ -463,13 +461,13 @@ class CLIPEncoder(nn.Module):
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:

@@ -562,13 +560,13 @@ class CLIPTextTransformer(nn.Module):
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

@@ -652,13 +650,13 @@ class CLIPTextModel(CLIPPreTrainedModel):
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

@@ -684,4 +682,4 @@ class CLIPTextModel(CLIPPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
\ No newline at end of file
@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
-        #self.dtype = torch.float16 if use_fp16 else torch.float32
+        # self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
@@ -17,6 +17,7 @@
import importlib
import os
from typing import Optional, Union

from huggingface_hub import snapshot_download

# CHANGE to diffusers.utils

@@ -64,7 +65,7 @@ class DiffusionPipeline(ConfigMixin):
            # set models
            setattr(self, name, module)

-        register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
+        register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}
        self.register(**register_dict)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
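A small illustration of what the "_module" entry above records, assuming the pipeline class is imported from a top-level modeling_glide.py as in the conversion script earlier in this commit:

# Inside register_modules, for a GLIDE instance imported from modeling_glide.py:
#   self.__module__                          -> "modeling_glide"
#   self.__module__.split(".")[-1] + ".py"   -> "modeling_glide.py"
# so the saved config remembers which module file defines the custom pipeline class.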
@@ -16,5 +16,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .gaussian_ddpm import GaussianDDPMScheduler
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
+from .gaussian_ddpm import GaussianDDPMScheduler
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

-import torch
 import math
-from torch import nn

 import numpy as np
+import torch
+from torch import nn

from ..configuration_utils import ConfigMixin

@@ -80,19 +81,13 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = (
-            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        )
+        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
-        self.posterior_mean_coef1 = (
-            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        )
-        self.posterior_mean_coef2 = (
-            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
-        )
+        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
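For reference, the posterior terms set up above implement the standard DDPM posterior q(x_{t-1} | x_t, x_0), with alpha_t = 1 - beta_t and bar-alpha_t the cumulative product of the alphas:

q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\big)

\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t,
\qquad
\tilde{\mu}_t(x_t, x_0) = \underbrace{\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}}_{\text{posterior\_mean\_coef1}} x_0 + \underbrace{\frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}}_{\text{posterior\_mean\_coef2}} x_t

with posterior_variance = \tilde{\beta}_t and posterior_log_variance_clipped its logarithm, clipped at t = 0 as noted in the code comment.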