"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "69e8dbb21d66120a72c9e0b076c017dde1a0da74"
Unverified Commit 1d4ad34a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Dreambooth] Make compatible with alt diffusion (#1470)

* [Dreambooth] Make compatible with alt diffusion

* make style

* add example
parent 20ce68f9
...@@ -195,6 +195,17 @@ accelerate launch train_dreambooth.py \ ...@@ -195,6 +195,17 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800 --max_train_steps=800
``` ```
### Using DreamBooth for other pipelines than Stable Diffusion
AltDiffusion also supports DreamBooth now; the running command is basically the same as above — all you need to do is replace the `MODEL_NAME` like this:
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
```
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
or
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
```
### Inference ### Inference
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt. Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
......
...@@ -14,18 +14,38 @@ from torch.utils.data import Dataset ...@@ -14,18 +14,38 @@ from torch.utils.data import Dataset
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer from transformers import AutoTokenizer, PretrainedConfig
logger = get_logger(__name__) logger = get_logger(__name__)
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str = None):
    """Return the text-encoder model class matching a pretrained checkpoint.

    Reads the ``text_encoder`` sub-config of the checkpoint and maps its
    recorded ``architectures`` entry to the corresponding model class, so the
    same training script works for Stable Diffusion (CLIP text encoder) and
    AltDiffusion (Roberta-series text encoder) checkpoints.

    Args:
        pretrained_model_name_or_path: Hub id or local path of the pipeline
            checkpoint to inspect.
        revision: Optional checkpoint revision to load. When ``None``, falls
            back to the module-level ``args.revision`` for backward
            compatibility with existing single-argument callers.

    Returns:
        The model *class* (not an instance) to use for the text encoder.

    Raises:
        ValueError: If the architecture recorded in the config is not one of
            the supported text-encoder classes.
    """
    if revision is None:
        # Original implementation read the parsed CLI namespace from module
        # scope; keep that fallback so existing call sites still work, while
        # allowing the revision to be passed explicitly (avoids a NameError
        # when the function is used before/without parse_args()).
        revision = args.revision
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None): def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument( parser.add_argument(
...@@ -357,7 +377,7 @@ def main(args): ...@@ -357,7 +377,7 @@ def main(args):
if cur_class_images < args.num_class_images: if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
safety_checker=None, safety_checker=None,
...@@ -407,19 +427,24 @@ def main(args): ...@@ -407,19 +427,24 @@ def main(args):
# Load the tokenizer # Load the tokenizer
if args.tokenizer_name: if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name, args.tokenizer_name,
revision=args.revision, revision=args.revision,
use_fast=False,
) )
elif args.pretrained_model_name_or_path: elif args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
subfolder="tokenizer", subfolder="tokenizer",
revision=args.revision, revision=args.revision,
use_fast=False,
) )
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained( text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
subfolder="text_encoder", subfolder="text_encoder",
revision=args.revision, revision=args.revision,
...@@ -649,7 +674,7 @@ def main(args): ...@@ -649,7 +674,7 @@ def main(args):
if global_step % args.save_steps == 0: if global_step % args.save_steps == 0:
if accelerator.is_main_process: if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet), unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=accelerator.unwrap_model(text_encoder),
...@@ -669,7 +694,7 @@ def main(args): ...@@ -669,7 +694,7 @@ def main(args):
# Create the pipeline using using the trained modules and save it. # Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process: if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet), unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=accelerator.unwrap_model(text_encoder),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment