Unverified Commit d03c9099 authored by Kashif Rasul, committed by GitHub

[Wuerstchen] text to image training script (#5052)



* initial script

* formatting

* prior trainer wip

* add efficient_net_encoder

* add CLIPTextModel

* add prior ema support

* optimizer

* fix typo

* add dataloader

* prompt_embeds and image_embeds

* initial training loop

* fix output_dir

* fix add_noise

* accelerator check

* make effnet_transforms dynamic

* fix training loop

* add validation logging

* use loaded text_encoder

* use PreTrainedTokenizerFast

* load weight from pickle

* save_model_card

* remove unused file

* fix typos

* save prior pipeline in its own folder

* fix imports

* fix pipe_t2i

* scale image_embeds

* remove snr_gamma

* format

* initial lora prior training

* log_validation and save

* initial gradient working

* remove save/load hooks

* set set_attn_processor on prior_prior

* add lora script

* typos

* use LoraLoaderMixin for prior pipeline

* fix usage

* make fix-copies

* use repo_id

* write_lora_layers is a staticmethod

* use defaults

* fix defaults

* undo

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add gradient checkpoint support to prior

* gradient_checkpointing

* formatting

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/loaders.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* use default unet and text_encoder

* fix test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 93df5bb6
# Würstchen text-to-image fine-tuning
## Running locally with PyTorch
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd into the example folder and run
```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
For this example we want to push the trained weights directly to the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training command. To log in, run:
```bash
huggingface-cli login
```
## Prior training
You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that `--gradient_checkpointing` is supported for prior fine-tuning; it trades extra compute for lower GPU memory usage, which helps on memory-constrained setups.
<br>
<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image_prior.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=768 \
--train_batch_size=4 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--dataloader_num_workers=4 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--checkpoints_total_limit=3 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--validation_prompts="A robot pokemon, 4k photo" \
--report_to="wandb" \
--push_to_hub \
--output_dir="wuerstchen-prior-pokemon-model"
```
<!-- accelerate_snippet_end -->
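Once training is done, the fine-tuned prior can be paired with the pretrained Würstchen decoder for inference. The snippet below is a rough sketch, not part of the training script: `"wuerstchen-prior-pokemon-model"` is a placeholder for wherever the trained prior pipeline ends up (the local output folder or your Hub repo), while `warp-ai/wuerstchen` is the public decoder checkpoint.
```python
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

device = "cuda"
dtype = torch.float16

# Placeholder path/repo id for the prior pipeline produced by the training run above.
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
    "wuerstchen-prior-pokemon-model", torch_dtype=dtype
).to(device)
# Pretrained decoder that turns the prior's image embeddings into images.
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=dtype
).to(device)

caption = "A robot pokemon, 4k photo"
prior_output = prior_pipeline(
    prompt=caption,
    height=768,
    width=768,
    timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    guidance_scale=4.0,
)
image = decoder_pipeline(
    image_embeddings=prior_output.image_embeddings,
    prompt=caption,
    guidance_scale=0.0,
    output_type="pil",
).images[0]
image.save("robot-pokemon.png")
```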
## Training with LoRA
Low-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA adapts a pretrained model by adding pairs of rank-decomposition matrices to existing weights and training **only** those newly added weights (see the minimal sketch after this list). This has a couple of advantages:
- The pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which makes the trained LoRA weights easily portable.
- LoRA attention layers let you control the extent to which the model is adapted toward new training images via a `scale` parameter.
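To make the idea concrete, here is a toy sketch of a linear layer wrapped with a rank-`r` LoRA update; it is illustrative only and is not the implementation used by `diffusers`. Only the two small matrices are trained while the original weight stays frozen:
```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Toy LoRA layer: y = W x + scale * B(A(x)), with only A and B trainable."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # frozen pretrained weight
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # start as a no-op so training begins at the pretrained model
        self.scale = scale

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))


layer = LoRALinear(nn.Linear(1280, 1280), rank=4)
out = layer(torch.randn(2, 1280))  # only lora_a / lora_b receive gradients
```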
### Prior Training
First, you need to set up your development environment as explained in the [installation](#running-locally-with-pytorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image_prior_lora.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME --caption_column="text" \
--resolution=768 \
--train_batch_size=8 \
--num_train_epochs=100 --checkpointing_steps=5000 \
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
--seed=42 \
--rank=4 \
--validation_prompt="cute dragon creature" \
--report_to="wandb" \
--push_to_hub \
--output_dir="wuerstchen-prior-pokemon-lora"
```
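After training, the LoRA weights can be loaded into the prior pipeline, which (with this PR) inherits from `LoraLoaderMixin`. A minimal sketch, assuming the weights were pushed to a placeholder Hub repo:
```python
import torch
from diffusers import WuerstchenPriorPipeline

prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
).to("cuda")
# Placeholder repo id / local folder containing the trained LoRA weights.
prior_pipeline.load_lora_weights("your-username/wuerstchen-prior-pokemon-lora")

image_embeddings = prior_pipeline(
    prompt="cute dragon creature", height=768, width=768
).image_embeddings
# `image_embeddings` can then be decoded with `WuerstchenDecoderPipeline`,
# as shown in the inference sketch in the "Prior training" section above.
```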
import torch.nn as nn
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class EfficientNetEncoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
        super().__init__()

        if effnet == "efficientnet_v2_s":
            self.backbone = efficientnet_v2_s(weights="DEFAULT").features
        else:
            self.backbone = efficientnet_v2_l(weights="DEFAULT").features
        self.mapper = nn.Sequential(
            nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        )

    def forward(self, x):
        return self.mapper(self.backbone(x))
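For reference, a quick sanity-check sketch of calling this encoder on a dummy batch. The 768×768 size and batch contents are illustrative; the output is the 16-channel image embedding used as the training target for the prior.
```python
import torch

# Instantiate with the default efficientnet_v2_s backbone (downloads torchvision weights).
encoder = EfficientNetEncoder()
encoder.eval()

images = torch.randn(2, 3, 768, 768)  # placeholder batch after the effnet transforms
with torch.no_grad():
    image_embeds = encoder(images)
print(image_embeds.shape)  # torch.Size([2, 16, 24, 24]) with the 32x downsampling backbone
```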
accelerate>=0.16.0
torchvision
transformers>=4.25.1
wandb
huggingface_hub
bitsandbytes
deepspeed
@@ -1208,7 +1208,7 @@ class LoraLoaderMixin:
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
-            unet=self.unet,
+            unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
@@ -1216,7 +1216,9 @@ class LoraLoaderMixin:
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
-            text_encoder=self.text_encoder,
+            text_encoder=getattr(self, self.text_encoder_name)
+            if not hasattr(self, "text_encoder")
+            else self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
@@ -1577,7 +1579,7 @@ class LoraLoaderMixin:
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
@@ -1961,7 +1963,7 @@ class LoraLoaderMixin:
     @classmethod
     def save_lora_weights(
-        self,
+        cls,
         save_directory: Union[str, os.PathLike],
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
@@ -2001,7 +2003,7 @@ class LoraLoaderMixin:
             unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
         )
-        unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
+        unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
         state_dict.update(unet_lora_state_dict)
         if text_encoder_lora_layers is not None:
@@ -2012,12 +2014,12 @@ class LoraLoaderMixin:
             )
             text_encoder_lora_state_dict = {
-                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
+                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
             }
             state_dict.update(text_encoder_lora_state_dict)
         # Save the model
-        self.write_lora_layers(
+        cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
             is_main_process=is_main_process,
@@ -2026,6 +2028,7 @@ class LoraLoaderMixin:
             safe_serialization=safe_serialization,
         )
+    @staticmethod
     def write_lora_layers(
         state_dict: Dict[str, torch.Tensor],
         save_directory: str,
@@ -3248,7 +3251,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
     @classmethod
     def save_lora_weights(
-        self,
+        cls,
         save_directory: Union[str, os.PathLike],
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
@@ -3299,7 +3302,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
         state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
-        self.write_lora_layers(
+        cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
             is_main_process=is_main_process,
...
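These changes let subclasses redirect where LoRA weights are loaded and saved via `unet_name` / `text_encoder_name`, and make `save_lora_weights` / `write_lora_layers` usable on the class itself rather than an instance. A hedged sketch of the effect; the state-dict key and output directory below are placeholders, not what the training script actually produces:
```python
import torch
from diffusers import WuerstchenPriorPipeline

# Placeholder LoRA state dict; in the training script this comes from the prior's attention processors.
lora_layers = {"blocks.0.attention.to_q.lora.up.weight": torch.zeros(64, 4)}

# Because WuerstchenPriorPipeline sets `unet_name = "prior"`, the keys are written with a
# "prior." prefix, so `load_lora_weights` can later route them back to the `prior` component.
WuerstchenPriorPipeline.save_lora_weights(
    save_directory="wuerstchen-prior-lora",  # placeholder output folder
    unet_lora_layers=lora_layers,
)
```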
@@ -14,16 +14,29 @@
 # limitations under the License.
 import math
+from typing import Dict, Union
 import torch
 import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import UNet2DConditionLoadersMixin
+from ...models.attention_processor import (
+    ADDED_KV_ATTENTION_PROCESSORS,
+    CROSS_ATTENTION_PROCESSORS,
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnProcessor,
+)
 from ...models.modeling_utils import ModelMixin
+from ...utils import is_torch_version
 from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
-class WuerstchenPrior(ModelMixin, ConfigMixin):
+class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+    unet_name = "prior"
+    _supports_gradient_checkpointing = True
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
@@ -45,6 +58,90 @@ class WuerstchenPrior(ModelMixin, ConfigMixin):
             nn.Conv2d(c, c_in * 2, kernel_size=1),
         )
+        self.gradient_checkpointing = False
+        self.set_default_attn_processor()
+
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor, _remove_lora=_remove_lora)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor, _remove_lora=True)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        self.gradient_checkpointing = value
+
     def gen_r_embedding(self, r, max_positions=10000):
         r = r * max_positions
         half_dim = self.c_r // 2
@@ -61,12 +158,42 @@ class WuerstchenPrior(ModelMixin, ConfigMixin):
         x = self.projection(x)
         c_embed = self.cond_mapper(c)
         r_embed = self.gen_r_embedding(r)
-        for block in self.blocks:
-            if isinstance(block, AttnBlock):
-                x = block(x, c_embed)
-            elif isinstance(block, TimestepBlock):
-                x = block(x, r_embed)
-            else:
-                x = block(x)
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                for block in self.blocks:
+                    if isinstance(block, AttnBlock):
+                        x = torch.utils.checkpoint.checkpoint(
+                            create_custom_forward(block), x, c_embed, use_reentrant=False
+                        )
+                    elif isinstance(block, TimestepBlock):
+                        x = torch.utils.checkpoint.checkpoint(
+                            create_custom_forward(block), x, r_embed, use_reentrant=False
+                        )
+                    else:
+                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+            else:
+                for block in self.blocks:
+                    if isinstance(block, AttnBlock):
+                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
+                    elif isinstance(block, TimestepBlock):
+                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
+                    else:
+                        x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
+        else:
+            for block in self.blocks:
+                if isinstance(block, AttnBlock):
+                    x = block(x, c_embed)
+                elif isinstance(block, TimestepBlock):
+                    x = block(x, r_embed)
+                else:
+                    x = block(x)
         a, b = self.out(x).chunk(2, dim=1)
         return (x_in - a) / ((1 - b).abs() + 1e-5)
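With `_supports_gradient_checkpointing = True` and the checkpointed forward pass above, activation checkpointing on the prior can be toggled through the standard `ModelMixin` API. A small sketch; `warp-ai/wuerstchen-prior` is the public prior checkpoint, and checkpointing only takes effect in training mode:
```python
import torch
from diffusers import WuerstchenPriorPipeline

pipeline = WuerstchenPriorPipeline.from_pretrained("warp-ai/wuerstchen-prior", torch_dtype=torch.float32)
prior = pipeline.prior  # the WuerstchenPrior model

prior.enable_gradient_checkpointing()  # sets `gradient_checkpointing` via _set_gradient_checkpointing
prior.train()  # the checkpointed branch above runs only when self.training is True
```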
@@ -20,6 +20,7 @@ import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer
+from ...loaders import LoraLoaderMixin
 from ...schedulers import DDPMWuerstchenScheduler
 from ...utils import (
     BaseOutput,
@@ -65,7 +66,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
     image_embeddings: Union[torch.FloatTensor, np.ndarray]
-class WuerstchenPriorPipeline(DiffusionPipeline):
+class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
     """
     Pipeline for generating image prior for Wuerstchen.
@@ -90,6 +91,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
            Default resolution for multiple images generated.
     """
+    unet_name = "prior"
+    text_encoder_name = "text_encoder"
     model_cpu_offload_seq = "text_encoder->prior"
     def __init__(
...
@@ -211,24 +211,15 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
+        timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
-        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
-        timesteps = timesteps.to(original_samples.device)
-
-        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-        return noisy_samples
+        device = original_samples.device
+        dtype = original_samples.dtype
+        alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
+            timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
+        )
+        noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
+        return noisy_samples.to(dtype=dtype)

     def __len__(self):
         return self.config.num_train_timesteps
...
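The reworked `add_noise` evaluates the scheduler's `_alpha_cumprod` at continuous timesteps instead of indexing a discrete `alphas_cumprod` table, which is why the `timesteps` argument is now a `FloatTensor`. A rough usage sketch on random data; the latent shape and the `[0, 1)` timestep range are illustrative assumptions, not taken from the training script:
```python
import torch
from diffusers import DDPMWuerstchenScheduler

scheduler = DDPMWuerstchenScheduler()

# Illustrative shapes: a batch of 4 prior latents / image embeddings.
latents = torch.randn(4, 16, 24, 24)
noise = torch.randn_like(latents)
timesteps = torch.rand(latents.size(0))  # continuous timesteps in [0, 1)

noisy_latents = scheduler.add_noise(latents, noise, timesteps)
print(noisy_latents.shape)  # torch.Size([4, 16, 24, 24])
```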