"vscode:/vscode.git/clone" did not exist on "a5bdb678c07f3875020c56a7d3001ac7e64c72b2"
Unverified Commit c1971a53 authored by Isamu Isozaki's avatar Isamu Isozaki Committed by GitHub
Browse files

Textual inv save log memory (#2184)



* Quality check and adding tokenizer

* Adapted stable diffusion to mixed precision+finished up style fixes

* Fixed based on patrick's review

* Fixed oom from number of validation images

* Removed unnecessary np.array conversion

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 41db2dbf
......@@ -781,7 +781,10 @@ def main():
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
unet=unet,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
......@@ -791,8 +794,11 @@ def main():
generator = (
None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
images = []
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
images.append(image)
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment