"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "187ea539aed54872675cd27f6a58b27084a173ab"
Commit ab946575 authored by patil-suraj's avatar patil-suraj
Browse files

add conversion script for BDDMPipeline

parent a1b5ef5d
import argparse
import torch
from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")
model = DiffWave()
model.load_state_dict(sd, strict=False)
ts, _, betas, _ = noise_scheduler_sd
ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())
noise_scheduler = DDPMScheduler(
timesteps=12,
trained_betas=betas,
timestep_values=ts,
clip_sample=False,
tensor_format="np",
)
pipeline = BDDMPipeline(model, noise_scheduler)
pipeline.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args()
convert_bddm_orginal(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment